Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

pull/2364/head
oahzxl 2023-01-10 11:29:01 +08:00
commit e532679c95
763 changed files with 78573 additions and 6836 deletions

View File

@ -20,6 +20,8 @@ body:
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Optional: Affiliation**
Providing your institution/email helps us better understand and support our users and improve the project. We also welcome in-depth cooperation.
placeholder: |
A clear and concise description of what the bug is.
validations:

View File

@ -17,6 +17,7 @@ body:
**Expectation** What is your expected content about it?
**Screenshots** If applicable, add screenshots to help explain your problem.
**Suggestions** Tell us how we could improve the documentation.
**Optional: Affiliation** Providing your institution/email helps us better understand and support our users and improve the project. We also welcome in-depth cooperation.
placeholder: |
A clear and concise description of the issue.
validations:

View File

@ -22,6 +22,8 @@ body:
If applicable, add screenshots to help explain your problem.
**Suggest a potential alternative/fix**
Tell us how we could improve this project.
**Optional: Affiliation**
Providing your institution/email helps us better understand and support our users and improve the project. We also welcome in-depth cooperation.
placeholder: |
A clear and concise description of your idea.
validations:

View File

@ -13,6 +13,7 @@ body:
- Bumping a critical dependency's major version;
- A significant improvement in user-friendliness;
- Significant refactor;
- Optional: Providing your affiliation/email helps us better understand and support our users and improve the project. We also welcome in-depth cooperation.
- ...
Please note that this template is not for feature requests or bug reports; using it for those may cause us to misidentify the issue and close it without taking action.
@ -43,4 +44,4 @@ body:
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!

View File

@ -1,9 +0,0 @@
addReviewers: true
addAssignees: author
numberOfReviewers: 1
reviewers:
- frankleeeee
- kurisusnowdeng

View File

@ -1,18 +0,0 @@
name: Assign Reviewers for Team
on:
pull_request:
types: [opened]
jobs:
assign_reviewer:
name: Assign Reviewer for PR
runs-on: ubuntu-latest
if: |
github.event.pull_request.draft == false && github.base_ref == 'main'
&& github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
&& toJson(github.event.pull_request.requested_reviewers) == '[]'
steps:
- uses: kentaro-m/auto-assign-action@v1.2.1
with:
configuration-path: '.github/reviewer_list.yml'

.github/workflows/auto_example_check.yml
View File

@ -0,0 +1,130 @@
name: Test Example
on:
pull_request:
# any change in the examples folder will trigger check for the corresponding example.
paths:
- 'examples/**'
# run at 00:00 every Sunday (Singapore time), i.e. 16:00 UTC on Saturday
schedule:
- cron: '0 16 * * 6'
jobs:
# Detect changed example files and output a matrix containing the corresponding directory names.
detect-changed-example:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.setup-matrix.outputs.matrix }}
anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
name: Detect changed example files
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.sha }}
- name: Get all changed example files
id: changed-files
uses: tj-actions/changed-files@v35
# Using since_last_remote_commit triggers the check on every push to the PR.
with:
since_last_remote_commit: true
- name: setup matrix
id: setup-matrix
run: |
changedFileName=""
for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
changedFileName="${file}:${changedFileName}"
done
echo "$changedFileName was changed"
res=`python .github/workflows/scripts/example_checks/detect_changed_example.py --fileNameList $changedFileName`
echo "All changed examples are $res"
if [ "$x" = "[]" ]; then
echo "anyChanged=false" >> $GITHUB_OUTPUT
echo "matrix=null" >> $GITHUB_OUTPUT
else
dirs=$( IFS=',' ; echo "${res[*]}" )
echo "anyChanged=true" >> $GITHUB_OUTPUT
echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT
fi
# If no file was changed, an error is raised showing that the matrix has no value.
check-changed-example:
# This condition avoids executing this job when the trigger event is the weekly schedule.
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
name: Test the changed example
needs: detect-changed-example
runs-on: [self-hosted, gpu]
strategy:
matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
steps:
- uses: actions/checkout@v3
- name: Install Colossal-AI
run: |
pip install -v .
- name: Test the example
run: |
example_dir=${{ matrix.directory }}
cd "${PWD}/examples/${example_dir}"
bash test_ci.sh
env:
NCCL_SHM_DISABLE: 1
# This is for all files' weekly check. Specifically, this job is to find all the directories.
matrix_preparation:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
name: Prepare matrix for weekly check
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.setup-matrix.outputs.matrix }}
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: setup matrix
id: setup-matrix
run: |
res=`python .github/workflows/scripts/example_checks/check_example_weekly.py`
all_loc=$( IFS=',' ; echo "${res[*]}" )
echo "Found the examples: $all_loc"
echo "matrix={\"directory\":$(echo "$all_loc")}" >> $GITHUB_OUTPUT
weekly_check:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule'
name: Weekly check all examples
needs: matrix_preparation
runs-on: [self-hosted, gpu]
strategy:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
timeout-minutes: 10
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: Install Colossal-AI
run: |
pip install -v .
- name: Traverse all files
run: |
example_dir=${{ matrix.directory }}
echo "Testing ${example_dir} now"
cd "${PWD}/examples/${example_dir}"
bash test_ci.sh
env:
NCCL_SHM_DISABLE: 1

View File

@ -1,17 +1,45 @@
name: Build
on:
pull_request:
types: [synchronize, labeled]
jobs:
build:
name: Build and Test Colossal-AI
detect:
name: Detect kernel-related file change
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
outputs:
changedFiles: ${{ steps.find-changed-files.outputs.changedFiles }}
anyChanged: ${{ steps.find-changed-files.outputs.any_changed }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.sha }}
- name: Find the changed files
id: find-changed-files
uses: tj-actions/changed-files@v35
with:
since_last_remote_commit: true
files: |
op_builder/**
colossalai/kernel/**
setup.py
- name: List changed files
run: |
for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
echo "$file was changed"
done
build:
name: Build and Test Colossal-AI
needs: detect
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:1.11.0-11.3.0
@ -23,27 +51,38 @@ jobs:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
conda install cmake
pip install -r requirements.txt
pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Install Colossal-AI
- name: Restore cache
if: needs.detect.outputs.anyChanged != 'true'
run: |
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
pip install -r requirements/requirements.txt
pip install -v -e .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
- name: Install Colossal-AI
run: |
CUDA_EXT=1 pip install -v -e .
pip install -r requirements/requirements-test.txt
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest tests
PYTHONPATH=$PWD pytest --cov=. --cov-report lcov tests
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
- name: Store Cache
run: |
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
cp -p -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/

View File

@ -2,7 +2,7 @@ name: Build on 8 GPUs
on:
schedule:
# run at 00:00 every day
- cron: '0 0 * * *'
workflow_dispatch:
@ -30,13 +30,11 @@ jobs:
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Install Colossal-AI
run: |
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
pip install -r requirements/requirements.txt
pip install -v -e .
CUDA_EXT=1 pip install -v -e .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
pip install -r requirements/requirements-test.txt
- name: Unit Testing
run: |
@ -45,4 +43,3 @@ jobs:
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64

View File

@ -70,7 +70,7 @@ jobs:
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Install Colossal-AI
run: |
pip install -r requirements/requirements.txt
pip install -v --no-cache-dir .

View File

@ -0,0 +1,63 @@
name: Manual Test Example
on:
workflow_dispatch:
inputs:
example_directory:
type: string
description: example directories, separated by spaces (e.g. language/gpt images/vit). Entering just 'language' or just 'gpt' will not work.
required: true
jobs:
matrix_preparation:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
name: Check the examples user want
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: Set up matrix
id: set-matrix
env:
check_dir: ${{ inputs.example_directory }}
run: |
res=`python .github/workflows/scripts/example_checks/check_dispatch_inputs.py --fileNameList $check_dir`
if [ res == "failure" ];then
exit -1
fi
dirs="[${check_dir}]"
echo "Testing examples in $dirs"
echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT
test_example:
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
name: Manually check example files
needs: matrix_preparation
runs-on: [self-hosted, gpu]
strategy:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
- name: Install Colossal-AI
run: |
pip install -v .
- name: Test the example
run: |
dir=${{ matrix.directory }}
echo "Testing ${dir} now"
cd "${PWD}/examples/${dir}"
bash test_ci.sh
env:
NCCL_SHM_DISABLE: 1

View File

@ -20,7 +20,7 @@ jobs:
fetch-depth: 0
- uses: actions/setup-python@v2
with:
python-version: '3.7.12'
python-version: '3.8.14'
- name: generate draft
id: generate_draft
run: |
@ -42,4 +42,3 @@ jobs:
body_path: ${{ steps.generate_draft.outputs.path }}
draft: True
prerelease: false

View File

@ -64,9 +64,21 @@ jobs:
- name: Copy scripts and checkout
run: |
cp -r ./.github/workflows/scripts/* ./
# link the cache directories to the current path
ln -s /github/home/conda_pkgs ./conda_pkgs
ln -s /github/home/pip_wheels ./pip_wheels
# set the conda package path
echo "pkgs_dirs:\n - $PWD/conda_pkgs" > ~/.condarc
# set safe directory
git config --global --add safe.directory /__w/ColossalAI/ColossalAI
# check out
git checkout $git_ref
# get cub package for cuda 10.2
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
env:

View File

@ -18,23 +18,17 @@ jobs:
with:
fetch-depth: 0
- name: Build Docker
id: build
run: |
version=$(cat version.txt)
docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t hpcaitech/colossalai:$version ./docker
tag=hpcaitech/colossalai:$version
docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t $tag ./docker
echo "tag=${tag}" >> $GITHUB_OUTPUT
- name: Log in to Docker Hub
uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38
with:
images: hpcaitech/colossalai
- name: Build and push Docker image
uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Push Docker image
run: |
docker push ${{ steps.build.outputs.tag }}

View File

@ -1,74 +1,29 @@
name: Release bdist wheel for Nightly versions
name: Publish Nightly Version to PyPI
on:
schedule:
# run at 00:00 of every Sunday
- cron: '0 0 * * 6'
workflow_dispatch:
jobs:
matrix_preparation:
name: Prepare Container List
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- id: set-matrix
run: |
matrix="[\"hpcaitech/cuda-conda:11.3\", \"hpcaitech/cuda-conda:10.2\"]"
echo $matrix
echo "::set-output name=matrix::{\"container\":$(echo $matrix)}"
schedule:
- cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time
build:
name: Release bdist wheels
needs: matrix_preparation
if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
runs-on: [self-hosted, gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
# cub is for cuda 10.2
- name: Copy scripts and checkout
run: |
cp -r ./.github/workflows/scripts/* ./
ln -s /github/home/pip_wheels ./pip_wheels
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
- name: Build bdist wheel
run: |
pip install beautifulsoup4 requests packaging
python ./build_colossalai_wheel.py --nightly
- name: 🚀 Deploy
uses: garygrossgarten/github-action-scp@release
with:
local: all_dist
remote: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
host: ${{ secrets.PRIVATE_PYPI_HOST }}
username: ${{ secrets.PRIVATE_PYPI_USER }}
password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
remove_old_build:
name: Remove old nightly build
jobs:
build-n-publish:
if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI'
name: Build and publish Python 🐍 distributions 📦 to PyPI
runs-on: ubuntu-latest
needs: build
timeout-minutes: 20
steps:
- name: executing remote ssh commands using password
uses: appleboy/ssh-action@master
env:
BUILD_DIR: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }}
with:
host: ${{ secrets.PRIVATE_PYPI_HOST }}
username: ${{ secrets.PRIVATE_PYPI_USER }}
password: ${{ secrets.PRIVATE_PYPI_PASSWD }}
envs: BUILD_DIR
script: |
cd $BUILD_DIR
find . -type f -mtime +0 -exec rm -f {} +
script_stop: true
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.8.14'
- run: NIGHTLY=1 python setup.py sdist build
# publish to PyPI if executed on the main branch
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
verbose: true

View File

@ -1,21 +1,29 @@
name: Publish to PyPI
on: workflow_dispatch
on:
workflow_dispatch:
pull_request:
paths:
- 'version.txt'
types:
- closed
jobs:
build-n-publish:
if: github.ref_name == 'main' && github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI' && github.event.pull_request.merged == true && github.base_ref == 'main'
name: Build and publish Python 🐍 distributions 📦 to PyPI
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.7.12'
python-version: '3.8.14'
- run: python setup.py sdist build
# publish to PyPI if executed on the main branch
# publish to Test PyPI if executed on the develop branch
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:

View File

@ -1,25 +0,0 @@
name: Publish to Test PyPI
on: workflow_dispatch
jobs:
build-n-publish:
if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor)
name: Build and publish Python 🐍 distributions 📦 to Test PyPI
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.7.12'
- run: python setup.py sdist build
# publish to PyPI if executed on the main branch
# publish to Test PyPI if executed on the develop branch
- name: Publish package to Test PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
repository_url: https://test.pypi.org/legacy/
verbose: true

View File

@ -1,12 +1,13 @@
from filecmp import cmp
import requests
from bs4 import BeautifulSoup
import argparse
import os
import subprocess
from packaging import version
from filecmp import cmp
from functools import cmp_to_key
import requests
from bs4 import BeautifulSoup
from packaging import version
WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels'
RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels'
CUDA_HOME = os.environ['CUDA_HOME']

View File

@ -18,7 +18,7 @@ if [ $1 == "pip" ]
then
wget -nc -q -O ./pip_wheels/$filename $url
pip install ./pip_wheels/$filename
elif [ $1 == 'conda' ]
then
conda install pytorch==$torch_version cudatoolkit=$cuda_version $flags
@ -34,8 +34,9 @@ fi
python setup.py bdist_wheel
mv ./dist/* ./all_dist
# must remove build to enable compilation for
# cuda extension in the next build
rm -rf ./build
python setup.py clean
conda deactivate
conda env remove -n $python_version

View File

@ -0,0 +1,27 @@
import argparse
import os
def check_inputs(input_list):
for path in input_list:
real_path = os.path.join('examples', path)
if not os.path.exists(real_path):
return False
return True
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
args = parser.parse_args()
name_list = args.fileNameList.split(",")
is_correct = check_inputs(name_list)
if is_correct:
print('success')
else:
print('failure')
if __name__ == '__main__':
main()
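A quick sanity check of the script above; this is a hedged sketch, and the directory names are hypothetical:

# demo only: assumes check_dispatch_inputs.py is importable from the working directory
from check_dispatch_inputs import check_inputs

print(check_inputs(['language/gpt']))   # True only if examples/language/gpt exists
print(check_inputs(['no/such/dir']))    # False -> the workflow prints 'failure' and exits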

View File

@ -0,0 +1,37 @@
import os
def show_files(path, all_files):
# Traverse all folders/files in the given directory
file_list = os.listdir(path)
# Determine whether each element is a folder or a file: append files to the list, recurse into folders.
for file_name in file_list:
# Build the absolute path with os.path.join() and store it in cur_path.
cur_path = os.path.join(path, file_name)
# Check whether it is a folder
if os.path.isdir(cur_path):
show_files(cur_path, all_files)
else:
all_files.append(cur_path)
return all_files
def join(input_list, sep=None):
return (sep or ' ').join(input_list)
def main():
contents = show_files('examples/', [])
all_loc = []
for file_loc in contents:
split_loc = file_loc.split('/')
# must have two sub-folder levels under the examples folder: examples/images/vit is acceptable; examples/images/README.md and examples/requirements.txt are not.
if len(split_loc) >= 4:
re_loc = '/'.join(split_loc[1:3])
if re_loc not in all_loc:
all_loc.append(re_loc)
print(all_loc)
if __name__ == '__main__':
main()
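For reference, a sketch of the output this script hands to the workflow (folder names are hypothetical):

# with examples/images/vit/train.py and examples/language/gpt/run.sh on disk,
# main() prints a Python list literal that the shell joins into the job matrix:
#   ['images/vit', 'language/gpt']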

View File

@ -0,0 +1,24 @@
import argparse
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
args = parser.parse_args()
name_list = args.fileNameList.split(":")
folder_need_check = set()
for loc in name_list:
# Keep only the sub-sub-folders of the 'examples' folder
# the examples folder structure is like
# - examples
# - area
# - application
# - file
if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
folder_need_check.add('/'.join(loc.split("/")[1:3]))
# Output the result with print so that the shell can capture the values.
print(list(folder_need_check))
if __name__ == '__main__':
main()
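A worked example of the folder-extraction rule above (file paths are made up):

paths = "examples/images/vit/train.py:examples/README.md:examples/language/gpt/run.sh"
folders = set()
for loc in paths.split(":"):
    parts = loc.split("/")
    if parts[0] == "examples" and len(parts) >= 4:
        folders.add("/".join(parts[1:3]))
print(sorted(folders))  # ['images/vit', 'language/gpt'] -- examples/README.md is filtered out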

View File

@ -2,9 +2,10 @@
# coding: utf-8
import argparse
import requests
import re
import os
import re
import requests
COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits'
TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags'

View File

@ -1,6 +1,6 @@
name: Synchronize Submodule
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
@ -27,11 +27,11 @@ jobs:
- name: Commit update
run: |
git config --global user.name 'github-actions'
git config --global user.email 'github-actions@github.com'
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}
git commit -am "Automated submodule synchronization"
- name: Create Pull Request
uses: peter-evans/create-pull-request@v3
with:
@ -43,4 +43,3 @@ jobs:
assignees: ${{ github.actor }}
delete-branch: true
branch: create-pull-request/patch-sync-submodule

.gitignore
View File

@ -134,10 +134,23 @@ dmypy.json
.vscode/
# macos
.DS_Store
*.DS_Store
#data/
docs/.build
# pytorch checkpoint
*.pt
# ignore version.py generated by setup.py
colossalai/version.py
# ignore any kernel build files
.o
.so
# ignore python interface definition files
.pyi
# ignore coverage test file
coverage.lcov

View File

@ -27,4 +27,4 @@ sphinx:
python:
install:
- requirements: requirements/requirements.txt
- requirements: docs/requirements.txt

View File

@ -1,4 +1,4 @@
Copyright 2021- The Colossal-ai Authors. All rights reserved.
Copyright 2021- HPC-AI Technology Inc. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
@ -187,7 +187,7 @@ Copyright 2021- The Colossal-ai Authors. All rights reserved.
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Copyright 2021- HPC-AI Technology Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,3 +1,4 @@
include *.txt README.md
recursive-include requirements *.txt
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
recursive-include op_builder *.py

View File

@ -1,14 +1,14 @@
# Colossal-AI
<div id="top" align="center">
[![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Colossal-AI_logo.png)](https://www.colossalai.org/)
[![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/)
Colossal-AI: A Unified Deep Learning System for the Big Model Era
<h3> <a href="https://arxiv.org/abs/2110.14883"> Paper </a> |
<a href="https://www.colossalai.org/"> Documentation </a> |
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> Examples </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
<a href="https://medium.com/@hpcaitech"> Blog </a></h3>
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
@ -22,41 +22,50 @@
</div>
## News
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
## Table of Contents
<ul>
<li><a href="#为何选择-Colossal-AI">Why Colossal-AI</a> </li>
<li><a href="#特点">Features</a> </li>
<li>
<a href="#并行训练样例展示">Parallel Training Demo</a>
<ul>
<li><a href="#ViT">ViT</a></li>
<li><a href="#GPT-3">GPT-3</a></li>
<li><a href="#GPT-2">GPT-2</a></li>
<li><a href="#BERT">BERT</a></li>
<li><a href="#PaLM">PaLM</a></li>
<li><a href="#OPT">OPT</a></li>
<li><a href="#ViT">ViT</a></li>
<li><a href="#推荐系统模型">Recommendation System Models</a></li>
</ul>
</li>
<li>
<a href="#单GPU训练样例展示">Single GPU Training Demo</a>
<ul>
<li><a href="#GPT-2-Single">GPT-2</a></li>
<li><a href="#PaLM-Single">PaLM</a></li>
</ul>
</li>
<li>
<a href="#推理-Energon-AI-样例展示">Inference (Energon-AI) Demo</a>
<ul>
<li><a href="#GPT-3-Inference">GPT-3</a></li>
<li><a href="#OPT-Serving">OPT-175B Online Serving for Text Generation</a></li>
<li><a href="#BLOOM-Inference">175B BLOOM</a></li>
</ul>
</li>
<li>
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI for Real World Applications</a>
<ul>
<li><a href="#xTrimoMultimer">xTrimoMultimer: Protein Monomer and Multimer Structure Prediction</a></li>
<li><a href="#AIGC">AIGC: Acceleration of Stable Diffusion</a></li>
<li><a href="#生物医药">Biomedicine: Acceleration of AlphaFold Protein Structure Prediction</a></li>
</ul>
</li>
<li>
@ -69,11 +78,6 @@
<li><a href="#使用-Docker">Use Docker</a></li>
<li><a href="#社区">Community</a></li>
<li><a href="#做出贡献">Contributing</a></li>
<li><a href="#快速预览">Quick View</a></li>
<ul>
<li><a href="#几行代码开启分布式训练">Start Distributed Training in Lines</a></li>
<li><a href="#构建一个简单的2维并行模型">Write a Simple 2D Parallel Model</a></li>
</ul>
<li><a href="#引用我们">Cite Us</a></li>
</ul>
@ -98,6 +102,7 @@ Colossal-AI provides a collection of parallel components. Our goal is to make your
- 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism
- [Sequence Parallelism](https://arxiv.org/abs/2105.13120)
- [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054)
- [Auto-Parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt)
- Heterogeneous Memory Management
- [PatrickStar](https://arxiv.org/abs/2108.05818)
- Friendly Usage
@ -105,16 +110,11 @@ Colossal-AI provides a collection of parallel components. Our goal is to make your
- Inference
- [Energon-AI](https://github.com/hpcaitech/EnergonAI)
- Colossal-AI in the Real World
- [xTrimoMultimer: Protein Monomer and Multimer Structure Prediction](https://github.com/biomap-research/xTrimoMultimer)
- Biomedicine: [FastFold](https://github.com/hpcaitech/FastFold) accelerates training and inference of AlphaFold protein structure prediction
<p align="right">(<a href="#top">back to top</a>)</p>
## Parallel Training Demo
### ViT
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/ViT.png" width="450" />
</p>
- 14x larger batch size and 5x faster training with Tensor Parallelism = 64
### GPT-3
<p align="center">
@ -131,7 +131,7 @@ Colossal-AI provides a collection of parallel components. Our goal is to make your
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/(updated)GPT-2.png" width=800>
- Train a 24x larger model with the same hardware
- Over 3x throughput
### BERT
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BERT.png" width=800/>
@ -145,10 +145,16 @@ Colossal-AI provides a collection of parallel components. Our goal is to make your
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_update.png" width=800/>
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-billion-parameter language model released by Meta. Because its pretrained weights are fully public, it facilitates downstream task development and application deployment.
- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt)
Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI-Examples) for more details.
### ViT
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/ViT.png" width="450" />
</p>
- 14x larger batch size and 5x faster training with Tensor Parallelism = 64
### Recommendation System Models
- [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding): use a software cache for embeddings to train larger models with a smaller GPU memory budget.
@ -178,7 +184,7 @@ Colossal-AI provides a collection of parallel components. Our goal is to make your
- Train a 34x larger model with the same hardware
<p align="right">(<a href="#top">back to top</a>)</p>
<p align="right">(<a href="#top">返回顶端</a>)</p>
## 推理 (Energon-AI) 样例展示
@ -195,23 +201,82 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
- [OPT推理服务](https://service.colossalai.org/opt): 无需注册免费体验1750亿参数OPT在线推理服务
<p id="BLOOM-Inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BLOOM%20Inference.PNG" width=800/>
</p>
<p align="right">(<a href="#top">back to top</a>)</p>
- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): reduce deployment and inference costs of the 175-billion-parameter BLOOM model by more than 10x.
<p align="right">(<a href="#top">back to top</a>)</p>
## Colossal-AI in the Real World
### xTrimoMultimer: Protein Monomer and Multimer Structure Prediction
### AIGC
Acceleration of AIGC (AI-generated content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion).
<p id="diffusion_train" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Stable%20Diffusion%20v2.png" width=800/>
</p>
- [Training](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): reduce Stable Diffusion memory consumption by 5.6x and hardware cost by up to 46x (from A100 to RTX3060).
<p id="diffusion_demo" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/DreamBooth.png" width=800/>
</p>
- [DreamBooth Fine-tuning](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): personalize your model using just 3-5 images of the desired subject.
<p id="inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Stable%20Diffusion%20Inference.jpg" width=800/>
</p>
- [Inference](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): reduce GPU inference memory consumption by 2.5x.
<p align="right">(<a href="#top">back to top</a>)</p>
### Biomedicine
Acceleration of [AlphaFold](https://alphafold.ebi.ac.uk/) protein structure prediction
<p id="FastFold" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/FastFold.jpg" width=800/>
</p>
- [FastFold](https://github.com/hpcaitech/FastFold): accelerate AlphaFold training and inference, speed up data preprocessing, and support inference on sequences longer than 10,000 residues.
<p id="xTrimoMultimer" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/xTM_Prediction.jpg" width=380/>
<p></p>
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/xTrimoMultimer_Table.jpg" width=800/>
</p>
- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): accelerate structure prediction of protein monomers and multimers by 11x.
<p align="right">(<a href="#top">back to top</a>)</p>
## Installation
### Install from PyPI
You can install Colossal-AI directly from PyPI with the command below. By default, we do not build PyTorch extensions during installation.
```bash
pip install colossalai
```
However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`.
```bash
CUDA_EXT=1 pip install colossalai
```
**Otherwise, PyTorch extensions will only be built at runtime, when you actually need them.**
We also publish a nightly version to PyPI every week, which lets you try out new features and bug fixes ahead of release. You can install the nightly version with the following command.
```bash
pip install colossalai-nightly
```
### Download From Official Releases
You can visit our [Download](https://www.colossalai.org/download) page to install Colossal-AI; the releases on that page come with pre-built CUDA extensions.
@ -231,10 +296,10 @@ pip install -r requirements/requirements.txt
pip install .
```
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using a fused optimizer):
By default, we do not build PyTorch extensions during `pip install`; they are compiled at runtime instead. If you want to build them ahead of time (needed when using fused optimizers), use the following command.
```shell
NO_CUDA_EXT=1 pip install .
CUDA_EXT=1 pip install .
```
<p align="right">(<a href="#top">返回顶端</a>)</p>
@ -283,31 +348,6 @@ docker run -ti --gpus all --rm --ipc=host colossalai bash
<p align="right">(<a href="#top">返回顶端</a>)</p>
## 快速预览
### 几行代码开启分布式训练
```python
parallel = dict(
pipeline=2,
tensor=dict(mode='2.5d', depth = 1, size=4)
)
```
### Start Heterogeneous Training in Lines
```python
zero = dict(
model_config=dict(
tensor_placement_policy='auto',
shard_strategy=TensorShardStrategy(),
reuse_fp16_shard=True
),
optimizer_config=dict(initial_scale=2**5, gpu_margin_mem_ratio=0.2)
)
```
<p align="right">(<a href="#top">返回顶端</a>)</p>
## 引用我们
@ -320,4 +360,4 @@ zero = dict(
}
```
<p align="right">(<a href="#top">返回顶端</a>)</p>
<p align="right">(<a href="#top">返回顶端</a>)</p>

README.md
View File

@ -1,14 +1,14 @@
# Colossal-AI
<div id="top" align="center">
[![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Colossal-AI_logo.png)](https://www.colossalai.org/)
[![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/)
Colossal-AI: A Unified Deep Learning System for Big Model Era
<h3> <a href="https://arxiv.org/abs/2110.14883"> Paper </a> |
<a href="https://www.colossalai.org/"> Documentation </a> |
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> Examples </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
<h3> <a href="https://arxiv.org/abs/2110.14883"> Paper </a> |
<a href="https://www.colossalai.org/"> Documentation </a> |
<a href="https://github.com/hpcaitech/ColossalAI-Examples"> Examples </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
<a href="https://medium.com/@hpcaitech"> Blog </a></h3>
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
@ -17,46 +17,55 @@
[![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech)
[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&amp)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
[![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&amp)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png)
| [English](README.md) | [中文](README-zh-Hans.md) |
</div>
## Latest News
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
## Table of Contents
<ul>
<li><a href="#Why-Colossal-AI">Why Colossal-AI</a> </li>
<li><a href="#Features">Features</a> </li>
<li>
<a href="#Parallel-Training-Demo">Parallel Training Demo</a>
<a href="#Parallel-Training-Demo">Parallel Training Demo</a>
<ul>
<li><a href="#ViT">ViT</a></li>
<li><a href="#GPT-3">GPT-3</a></li>
<li><a href="#GPT-2">GPT-2</a></li>
<li><a href="#BERT">BERT</a></li>
<li><a href="#PaLM">PaLM</a></li>
<li><a href="#OPT">OPT</a></li>
<li><a href="#ViT">ViT</a></li>
<li><a href="#Recommendation-System-Models">Recommendation System Models</a></li>
</ul>
</li>
<li>
<a href="#Single-GPU-Training-Demo">Single GPU Training Demo</a>
<a href="#Single-GPU-Training-Demo">Single GPU Training Demo</a>
<ul>
<li><a href="#GPT-2-Single">GPT-2</a></li>
<li><a href="#PaLM-Single">PaLM</a></li>
</ul>
</li>
<li>
<a href="#Inference-Energon-AI-Demo">Inference (Energon-AI) Demo</a>
<a href="#Inference-Energon-AI-Demo">Inference (Energon-AI) Demo</a>
<ul>
<li><a href="#GPT-3-Inference">GPT-3</a></li>
<li><a href="#OPT-Serving">OPT-175B Online Serving for Text Generation</a></li>
<li><a href="#BLOOM-Inference">175B BLOOM</a></li>
</ul>
</li>
<li>
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI for Real World Applications</a>
<a href="#Colossal-AI-in-the-Real-World">Colossal-AI for Real World Applications</a>
<ul>
<li><a href="#xTrimoMultimer">xTrimoMultimer: Accelerating Protein Monomer and Multimer Structure Prediction</a></li>
<li><a href="#AIGC">AIGC: Acceleration of Stable Diffusion</a></li>
<li><a href="#Biomedicine">Biomedicine: Acceleration of AlphaFold Protein Structure</a></li>
</ul>
</li>
<li>
@ -69,11 +78,6 @@
<li><a href="#Use-Docker">Use Docker</a></li>
<li><a href="#Community">Community</a></li>
<li><a href="#contributing">Contributing</a></li>
<li><a href="#Quick-View">Quick View</a></li>
<ul>
<li><a href="#Start-Distributed-Training-in-Lines">Start Distributed Training in Lines</a></li>
<li><a href="#Write-a-Simple-2D-Parallel-Model">Write a Simple 2D Parallel Model</a></li>
</ul>
<li><a href="#Cite-Us">Cite Us</a></li>
</ul>
@ -100,8 +104,9 @@ distributed training and inference in a few lines.
- 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism
- [Sequence Parallelism](https://arxiv.org/abs/2105.13120)
- [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054)
- [Auto-Parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt)
- Heterogeneous Memory Management
- [PatrickStar](https://arxiv.org/abs/2108.05818)
- Friendly Usage
@ -110,17 +115,11 @@ distributed training and inference in a few lines.
- Inference
- [Energon-AI](https://github.com/hpcaitech/EnergonAI)
- Colossal-AI in the Real World
- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): Accelerating Protein Monomer and Multimer Structure Prediction
- Biomedicine: [FastFold](https://github.com/hpcaitech/FastFold) accelerates training and inference of AlphaFold protein structure
<p align="right">(<a href="#top">back to top</a>)</p>
## Parallel Training Demo
### ViT
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/ViT.png" width="450" />
</p>
- 14x larger batch size, and 5x faster training for Tensor Parallelism = 64
### GPT-3
<p align="center">
@ -150,10 +149,17 @@ distributed training and inference in a few lines.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT_update.png" width=800/>
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-billion-parameter AI language model released by Meta, which stimulates AI programmers to perform various downstream tasks and application deployments because its pretrained model weights are fully public.
- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt)
Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI-Examples) for more details.
### ViT
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/ViT.png" width="450" />
</p>
- 14x larger batch size, and 5x faster training for Tensor Parallelism = 64
### Recommendation System Models
- [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding), utilize software cache to train larger embedding tables with a smaller GPU memory budget.
@ -198,26 +204,85 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
- [OPT Serving](https://service.colossalai.org/opt): Try 175-billion-parameter OPT online services for free, without any registration whatsoever.
<p id="BLOOM-Inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/BLOOM%20Inference.PNG" width=800/>
</p>
- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): Reduce hardware deployment costs of 175-billion-parameter BLOOM by more than 10 times.
<p align="right">(<a href="#top">back to top</a>)</p>
## Colossal-AI in the Real World
### xTrimoMultimer: Accelerating Protein Monomer and Multimer Structure Prediction
### AIGC
Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion).
<p id="diffusion_train" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Stable%20Diffusion%20v2.png" width=800/>
</p>
- [Training](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce Stable Diffusion memory consumption by up to 5.6x and hardware cost by up to 46x (from A100 to RTX3060).
<p id="diffusion_demo" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/DreamBooth.png" width=800/>
</p>
- [DreamBooth Fine-tuning](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): Personalize your model using just 3-5 images of the desired subject.
<p id="inference" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Stable%20Diffusion%20Inference.jpg" width=800/>
</p>
- [Inference](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce inference GPU memory consumption by 2.5x.
<p align="right">(<a href="#top">back to top</a>)</p>
### Biomedicine
Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
<p id="FastFold" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/FastFold.jpg" width=800/>
</p>
- [FastFold](https://github.com/hpcaitech/FastFold): accelerating training and inference on GPU Clusters, faster data processing, inference sequence containing more than 10000 residues.
<p id="xTrimoMultimer" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/xTM_Prediction.jpg" width=380/>
<p></p>
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/xTrimoMultimer_Table.jpg" width=800/>
</p>
- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): accelerating structure prediction of protein monomers and multimers by 11x.
<p align="right">(<a href="#top">back to top</a>)</p>
## Installation
### Install from PyPI
You can easily install Colossal-AI with the following command. **By default, we do not build PyTorch extensions during installation.**
```bash
pip install colossalai
```
However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`.
```bash
CUDA_EXT=1 pip install colossalai
```
**Otherwise, CUDA kernels will be built at runtime, when you actually need them.**
We also release a nightly version to PyPI on a weekly basis. This allows you to access unreleased features and bug fixes from the main branch.
You can install it via
```bash
pip install colossalai-nightly
```
### Download From Official Releases
You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built CUDA extensions.
You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built PyTorch extensions.
### Download From Source
@ -228,17 +293,15 @@ You can visit the [Download](https://www.colossalai.org/download) page to downlo
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
# install dependency
pip install -r requirements/requirements.txt
# install colossalai
pip install .
```
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
By default, we do not compile CUDA/C++ kernels. ColossalAI will build them during runtime.
If you want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
```shell
NO_CUDA_EXT=1 pip install .
CUDA_EXT=1 pip install .
```
<p align="right">(<a href="#top">back to top</a>)</p>
@ -289,32 +352,6 @@ Thanks so much to all of our amazing contributors!
<p align="right">(<a href="#top">back to top</a>)</p>
## Quick View
### Start Distributed Training in Lines
```python
parallel = dict(
pipeline=2,
tensor=dict(mode='2.5d', depth = 1, size=4)
)
```
### Start Heterogeneous Training in Lines
```python
zero = dict(
model_config=dict(
tensor_placement_policy='auto',
shard_strategy=TensorShardStrategy(),
reuse_fp16_shard=True
),
optimizer_config=dict(initial_scale=2**5, gpu_margin_mem_ratio=0.2)
)
```
<p align="right">(<a href="#top">back to top</a>)</p>
## Cite Us

View File

@ -7,4 +7,11 @@ from .initialize import (
launch_from_torch,
)
__version__ = '0.1.11rc1'
try:
# .version will be created by setup.py
from .version import __version__
except ModuleNotFoundError:
# this will only happen if the user did not run `pip install`
# and directly set PYTHONPATH to use Colossal-AI which is a bad practice
__version__ = '0.0.0'
print('please install Colossal-AI from https://www.colossalai.org/download or from source')
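For context, a hedged sketch of what the generated `.version` module looks like (the actual generation lives in setup.py, which is not part of this hunk):

# colossalai/version.py -- hypothetical generated content
__version__ = '0.2.0'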

View File

@ -1,7 +1,8 @@
from .apex_amp import ApexAMPOptimizer
import torch.nn as nn
from torch.optim import Optimizer
from .apex_amp import ApexAMPOptimizer
def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
r"""A helper function to wrap training components with Apex AMP modules

View File

@ -1,10 +1,13 @@
import inspect
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
from .grad_scaler import DynamicGradScaler, ConstantGradScaler
from ._fp16_optimizer import FP16Optimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer
def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):

View File

@ -3,24 +3,33 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.logging import get_dist_logger
from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier
from ._utils import has_inf_or_nan, zero_gard_by_list
from .grad_scaler import BaseGradScaler
try:
import colossal_C
from colossalai._C import fused_optim
except:
print('Colossalai should be built with cuda extension to use the FP16 optimizer')
from torch.optim import Optimizer
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger
from colossalai.utils import (copy_tensor_parallel_attributes, clip_grad_norm_fp32, multi_tensor_applier)
from torch.distributed import ProcessGroup
from .grad_scaler import BaseGradScaler
from ._utils import has_inf_or_nan, zero_gard_by_list
fused_optim = None
__all__ = ['FP16Optimizer']
def load_fused_optim():
global fused_optim
if fused_optim is None:
fused_optim = FusedOptimBuilder().load()
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
"""
adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
@ -33,7 +42,9 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
if overflow_buf:
overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
global fused_optim
load_fused_optim()
multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
else:
for this_, that_ in zip(this, that):
that_.copy_(this_)
@ -41,7 +52,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
class FP16Optimizer(Optimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD
grad_scaler (BaseGradScaler): grad scaler for gradient chose in
@ -73,8 +84,8 @@ class FP16Optimizer(Optimizer):
# get process group
def _get_process_group(parallel_mode):
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA):
return gpc.get_group(ParallelMode.DATA)
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode):
return gpc.get_group(parallel_mode)
else:
return None
@ -150,6 +161,12 @@ class FP16Optimizer(Optimizer):
f"==========================================",
ranks=[0])
@property
def max_norm(self):
"""Returns the maximum norm of gradient clipping.
"""
return self._clip_grad_max_norm
@property
def grad_scaler(self):
"""Returns the gradient scaler.

View File

@ -1,4 +1,5 @@
from typing import List
from torch import Tensor

View File

@ -1,12 +1,14 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from abc import ABC, abstractmethod
from colossalai.logging import get_dist_logger
from torch import Tensor
from typing import Dict
import torch
from torch import Tensor
from colossalai.logging import get_dist_logger
__all__ = ['BaseGradScaler']

View File

@ -1,10 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from .base_grad_scaler import BaseGradScaler
from typing import Optional
import torch
from .base_grad_scaler import BaseGradScaler
__all__ = ['DynamicGradScaler']

View File

@ -1,17 +1,20 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import Any
from torch.optim import Optimizer
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.optimizer import ColossalaiOptimizer
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ReduceOp
from torch.optim import Optimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._fp16_optimizer import FP16Optimizer
@ -40,7 +43,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
return self.optim.step()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
pass
if self.optim.max_norm == max_norm:
return
raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). "
"If you have supplied clip_grad_norm in the amp_config, "
"executing the method clip_grad_norm is not allowed.")
class NaiveAMPModel(nn.Module):

View File

@ -1,10 +1,13 @@
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from colossalai.context import Config
from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
from typing import Optional
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.context import Config
from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer
def convert_to_torch_amp(model: nn.Module,
optimizer: Optimizer,

View File

@ -3,16 +3,18 @@
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
# to support tensor parallel
import torch
from collections import defaultdict, abc
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from colossalai.context import ParallelMode
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from packaging import version
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class _MultiDeviceReplicator(object):

View File

@ -1,17 +1,17 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
import torch.cuda.amp as torch_amp
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from ._grad_scaler import GradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
from ._grad_scaler import GradScaler
class TorchAMPOptimizer(ColossalaiOptimizer):
"""A wrapper class which integrate Pytorch AMP with an optimizer

View File

@ -0,0 +1,3 @@
from .ckpt_solver_base import CheckpointSolverBase
from .ckpt_solver_chen import CheckpointSolverChen
from .ckpt_solver_rotor import CheckpointSolverRotor

View File

@ -0,0 +1,16 @@
import os
from setuptools import Extension, setup
this_dir = os.path.dirname(os.path.abspath(__file__))
ext_modules = [Extension(
'rotorc',
sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')],
)]
setup(
name='rotor c extension',
version='0.1',
description='rotor c extension for faster dp computing',
ext_modules=ext_modules,
)
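Presumably the extension is built in place before the rotor solver loads it; the standard setuptools invocation would be (an assumption, since the build step is not shown in this diff):

# python setup.py build_ext --inplace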

View File

@ -0,0 +1,195 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, List
import torch
from torch.fx import Graph, Node
from colossalai.auto_parallel.passes.runtime_apply_pass import (
runtime_apply,
runtime_apply_for_iterable_object,
runtime_comm_spec_apply,
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
__all__ = ['CheckpointSolverBase']
def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes):
if n_src.op == 'output':
n_dst.meta = n_src.meta
def _get_param_size(module: torch.nn.Module):
"""Get the size of the parameters in the module"""
return sum([p.numel() * torch.tensor([], dtype=p.dtype).element_size() for p in module.parameters()])
class CheckpointSolverBase(ABC):
def __init__(
self,
graph: Graph,
free_memory: float = -1.0,
requires_linearize: bool = False,
cnode: List[str] = None,
optim_multiplier: float = 1.0,
):
"""``CheckpointSolverBase`` class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for target
computing graph.
Existing Solvers:
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
Rotor solver: https://hal.inria.fr/hal-02352969 (CheckpointSolverRotor)
Args:
graph (Graph): The computing graph to be optimized.
free_memory (float): Memory constraint for the solution.
requires_linearize (bool): Whether the graph needs to be linearized.
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
optim_multiplier (float, optional): The multiplier of extra weight storage for the
``torch.optim.Optimizer``. Default to 1.0.
Warnings:
Meta information of the graph is required for any ``CheckpointSolver``.
"""
# super-dainiu: this graph is a temporary graph which can refer to
# the owning module, but we will return another deepcopy of it after
# the solver is executed.
self.graph = deepcopy(graph)
self.graph.owning_module = graph.owning_module
_copy_output(graph, self.graph)
self.graph.set_codegen(ActivationCheckpointCodeGen())
# check if meta information is available
if any(len(node.meta) == 0 for node in self.graph.nodes):
raise RuntimeError(
"Node meta information hasn't been prepared! Please extract it from the graph before constructing the solver!"
)
# parameter memory = parameter size + optimizer extra weight storage
self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
self.cnode = cnode
self.requires_linearize = requires_linearize
if self.requires_linearize:
self.node_list = self._linearize_graph()
else:
self.node_list = self.get_node_list()
@abstractmethod
def solve(self):
"""Solve the checkpointing problem and return the solution.
"""
pass
def get_node_list(self):
"""Get the node list.
"""
return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]:
"""Linearize ``self.graph``.
Returns:
List[List[Node]]: List of lists; each inner list of Nodes represents
one "node" of the linearized graph.
Remarks:
Inplace ops and shape-consistency ops are merged into the previous node.
"""
# Common nodes are nodes that can be seen as attributes and remain
# unchanged throughout the whole model; they are used several times by
# different blocks of the model, which makes it hard to linearize the
# graph when we encounter them. We let users annotate some of the inputs
# as common nodes, such as the attention mask, and the ops below can also
# be treated as common nodes. With our common node propagation we can
# find some of the "real" common nodes (e.g. the real attention mask
# used in BERT and GPT). The rule is simple: a node whose parents are
# all common nodes, or whose op belongs to the following operations, is
# viewed as a newly born common node.
# List of target name that could be seen as common node
common_ops = ["getattr", "getitem", "size"]
def _is_cop(target: Any) -> bool:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if isinstance(target, str):
return target in common_ops
else:
return target.__name__ in common_ops
def _is_sink() -> bool:
"""Check if we can free all dependencies
Returns:
bool
"""
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
def _is_shape_consistency(n: Node):
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
map(_is_shape_consistency, n.users))
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
self.cnode = []
deps = {}
node_list = []
region = []
for n in self.graph.nodes:
if n.op != "placeholder" and n.op != "output":
for n_par in n.all_input_nodes:
if n_par.op != "placeholder" and n_par.name not in self.cnode:
deps[n_par] -= 1
region.append(n)
# if the node could free all dependencies in graph
# we could begin a new node
if _is_sink():
node_list.append(region)
region = []
# propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
]) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
return node_list
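
Since ``CheckpointSolverBase`` is abstract, a concrete solver only has to implement ``solve``. A minimal sketch, assuming node meta information has already been attached to the graph; ``MyNaiveSolver`` is a hypothetical name, and marking every region as its own checkpoint is valid but rarely optimal:

from copy import deepcopy

from torch.fx import Graph

class MyNaiveSolver(CheckpointSolverBase):

    def solve(self) -> Graph:
        # annotate each linearized region as its own checkpoint region
        for i, nodes in enumerate(self.node_list):
            for n in nodes:
                if n.op in ('call_module', 'call_method', 'call_function'):
                    n.meta['activation_checkpoint'] = i
        return deepcopy(self.graph)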

View File

@@ -0,0 +1,87 @@
import math
from copy import deepcopy
from typing import List, Set, Tuple
from torch.fx import Graph, Node
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
from .ckpt_solver_base import CheckpointSolverBase
__all__ = ['CheckpointSolverChen']
class CheckpointSolverChen(CheckpointSolverBase):
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
"""
This is a simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
Note that this algorithm targets memory optimization only, using the techniques in appendix A.
Usage:
Assume that we have a ``GraphModule``, and we have already done the extractions
to the graph to retrieve all information needed, then we could use the following
code to find a solution using ``CheckpointSolverChen``:
>>> solver = CheckpointSolverChen(gm.graph)
>>> chen_graph = solver.solve()
>>> gm.graph = chen_graph # set the graph to a new graph
Args:
graph (Graph): The computing graph to be optimized.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
num_grids (int, optional): Number of grids to search for b. Defaults to 6.
"""
super().__init__(graph, free_memory=0, requires_linearize=True, cnode=cnode)
self.num_grids = num_grids
def solve(self) -> Graph:
"""Solve the checkpointing problem using Algorithm 3.
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
ckpt = self.grid_search()
for i, seg in enumerate(ckpt):
for idx in range(*seg):
nodes = self.node_list[idx]
for n in nodes:
if n.op in checkpointable_op:
n.meta['activation_checkpoint'] = i
return deepcopy(self.graph)
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
"""
This is a simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
"""
ckpt_intv = []
temp = 0
x = 0
y = 0
prev_idx = 2
for idx, nodes in enumerate(self.node_list):
for n in nodes:
n: Node
temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
y = max(y, temp)
if temp > b and idx > prev_idx:
x += calculate_fwd_in(nodes[0])
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
return ckpt_intv, math.floor(math.sqrt(x * y))
def grid_search(self) -> Set:
"""
Search for a checkpoint strategy with b = 0, then run the allocation algorithm again with b = √(xy).
Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
"""
_, b_approx = self.run_chen_greedy(0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
b_opt = math.inf
for b in range(b_min, b_max, (b_max - b_min) // self.num_grids):
ckpt_intv, b_approx = self.run_chen_greedy(b)
if b_approx < b_opt:
b_opt = b_approx
ckpt_opt = ckpt_intv
return ckpt_opt
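
The budget heuristic in ``run_chen_greedy`` and ``grid_search`` can be sanity-checked with a small worked example; the byte counts below are made up for illustration only:

import math

x, y = 4_000_000, 9_000_000                   # hypothetical greedy-pass statistics
b_approx = math.floor(math.sqrt(x * y))       # 6_000_000, the sqrt(xy) budget
b_min = math.floor(b_approx / math.sqrt(2))   # 4_242_640
b_max = math.ceil(b_approx * math.sqrt(2))    # 8_485_282
step = (b_max - b_min) // 6                   # 707_107 with num_grids = 6
candidates = list(range(b_min, b_max, step))  # six budgets for the greedy re-runs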

View File

@@ -0,0 +1,197 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
long* PySequenceToLongArray(PyObject* pylist) {
if (!(pylist && PySequence_Check(pylist))) return NULL;
Py_ssize_t len = PySequence_Size(pylist);
long* result = (long*)calloc(len + 1, sizeof(long));
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* item = PySequence_GetItem(pylist, i);
result[i] = PyLong_AsLong(item);
Py_DECREF(item);
}
result[len] = 0;
return result;
}
double* PySequenceToDoubleArray(PyObject* pylist) {
if (!(pylist && PySequence_Check(pylist))) return NULL;
Py_ssize_t len = PySequence_Size(pylist);
double* result = (double*)calloc(len + 1, sizeof(double));
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* item = PySequence_GetItem(pylist, i);
result[i] = PyFloat_AsDouble(item);
Py_DECREF(item);
}
result[len] = 0;
return result;
}
long* getLongArray(PyObject* container, const char* attributeName) {
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
long* result = PySequenceToLongArray(sequence);
Py_DECREF(sequence);
return result;
}
double* getDoubleArray(PyObject* container, const char* attributeName) {
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
double* result = PySequenceToDoubleArray(sequence);
Py_DECREF(sequence);
return result;
}
static PyObject* computeTable(PyObject* self, PyObject* args) {
PyObject* chainParam;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL;
double* ftime = getDoubleArray(chainParam, "ftime");
if (!ftime) return NULL;
double* btime = getDoubleArray(chainParam, "btime");
if (!btime) return NULL;
long* x = getLongArray(chainParam, "x");
if (!x) return NULL;
long* xbar = getLongArray(chainParam, "xbar");
if (!xbar) return NULL;
long* ftmp = getLongArray(chainParam, "ftmp");
if (!ftmp) return NULL;
long* btmp = getLongArray(chainParam, "btmp");
if (!btmp) return NULL;
long chainLength = PyObject_Length(chainParam);
if (!chainLength) return NULL;
#define COST_TABLE(m, i, l) \
costTable[(m) * (chainLength + 1) * (chainLength + 1) + \
(i) * (chainLength + 1) + (l)]
double* costTable = (double*)calloc(
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(double));
#define BACK_PTR(m, i, l) \
backPtr[(m) * (chainLength + 1) * (chainLength + 1) + \
(i) * (chainLength + 1) + (l)]
long* backPtr = (long*)calloc(
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long));
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chainLength; ++i)
if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
(m >= x[i + 1] + xbar[i + 1] + ftmp[i]))
COST_TABLE(m, i, i) = ftime[i] + btime[i];
else
COST_TABLE(m, i, i) = INFINITY;
for (long m = 0; m <= mmax; ++m)
for (long d = 1; d <= chainLength; ++d) {
for (long i = 0; i <= chainLength - d; ++i) {
long idx = i + d;
long mmin = x[idx + 1] + x[i + 1] + ftmp[i];
if (idx > i + 1) {
long maxCostFWD = 0;
for (long j = i + 1; j < idx; j++) {
maxCostFWD = fmaxl(maxCostFWD, x[j] + x[j + 1] + ftmp[j]);
}
mmin = fmaxl(mmin, x[idx + 1] + maxCostFWD);
}
if ((m >= mmin)) {
long bestLeaf = -1;
double sumFw = 0;
double bestLeafCost = INFINITY;
for (long j = i + 1; j <= idx; ++j) {
sumFw += ftime[j - 1];
if (m >= x[j]) {
double cost = sumFw + COST_TABLE(m - x[j], j, idx) +
COST_TABLE(m, i, j - 1);
if (cost < bestLeafCost) {
bestLeafCost = cost;
bestLeaf = j;
}
}
}
double chainCost = INFINITY;
if (m >= xbar[i + 1])
chainCost =
COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
if (bestLeafCost <= chainCost) {
COST_TABLE(m, i, idx) = bestLeafCost;
BACK_PTR(m, i, idx) = bestLeaf;
} else {
COST_TABLE(m, i, idx) = chainCost;
BACK_PTR(m, i, idx) = -1;
}
} else
COST_TABLE(m, i, idx) = INFINITY;
}
}
free(ftime);
free(btime);
free(x);
free(xbar);
free(ftmp);
free(btmp);
PyObject* pyCostTable = PyList_New(mmax + 1);
PyObject* pyBackPtr = PyList_New(mmax + 1);
// Convert the result into Python world
for (long m = 0; m <= mmax; ++m) {
PyObject* pyCostTable_m = PyList_New(chainLength + 1);
PyList_SET_ITEM(pyCostTable, m, pyCostTable_m);
PyObject* pyBackPtr_m = PyList_New(chainLength + 1);
PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m);
for (long i = 0; i <= chainLength; ++i) {
PyObject* pyCostTable_m_i = PyDict_New();
PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i);
PyObject* pyBackPtr_m_i = PyDict_New();
PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i);
for (long l = i; l <= chainLength; ++l) {
PyObject* pyVar_l = PyLong_FromLong(l);
PyObject* pyCostTable_m_i_l = PyFloat_FromDouble(COST_TABLE(m, i, l));
PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l);
Py_DECREF(pyCostTable_m_i_l);
PyObject* pyBackPtr_m_i_l;
if (BACK_PTR(m, i, l) < 0)
pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
else
pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l);
Py_DECREF(pyBackPtr_m_i_l);
Py_DECREF(pyVar_l);
}
}
}
free(costTable);
free(backPtr);
PyObject* result = PyTuple_Pack(2, pyCostTable, pyBackPtr);
Py_DECREF(pyCostTable);
Py_DECREF(pyBackPtr);
return result;
}
static PyMethodDef rotorMethods[] = {
{"compute_table", computeTable, METH_VARARGS,
"Compute the optimal table with the rotor algorithm."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
static struct PyModuleDef rotorModule = {
PyModuleDef_HEAD_INIT, "rotorc", /* name of module */
"A simple implementation of dynamic programming algorithm rotor with C in "
"https://hal.inria.fr/hal-02352969. Some code are adapted from "
"https://gitlab.inria.fr/hiepacs/rotor.", /* module documentation, may be
NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
rotorMethods};
PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); }
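
Once built, the extension can be driven directly from Python. A hedged sketch (normally ``CheckpointSolverRotor._compute_table_c``, shown in the next file, handles both the lazy build and the import); ``chain`` is assumed to expose ``ftime``, ``btime``, ``x``, ``xbar``, ``ftmp`` and ``btmp`` sequences and support ``len()``, as the ``Chain`` class in ``operation.py`` does:

from rotorc import compute_table  # built in place by build_c_ext.py

# mmax is the number of discretized memory slots; the result mirrors the
# pure-Python _compute_table: a cost table and a backtracking pointer table.
cost_table, back_ptr = compute_table(chain, 500)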

View File

@@ -0,0 +1,441 @@
from copy import deepcopy
from typing import Any, Dict, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
activation_size,
calculate_bwd_time,
calculate_fwd_out,
calculate_fwd_time,
calculate_fwd_tmp,
)
from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
__all__ = ['CheckpointSolverRotor']
class CheckpointSolverRotor(CheckpointSolverBase):
def __init__(self,
graph: Graph,
free_memory: float = -1,
cnode: List[str] = None,
memory_slots: int = 500,
optim_multiplier: float = 1.0):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
Usage:
Assume that we have a ``GraphModule``, and we have already done the extractions
to the graph to retrieve all information needed, then we could use the following
code to find a solution using ``CheckpointSolverRotor``:
>>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph
Args:
graph (Graph): The computing graph to be optimized.
free_memory (float, optional): Memory constraint for the solution, in bytes.
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
optim_multiplier (float, optional): The multiplier of extra weight storage for the
``torch.optim.Optimizer``. Defaults to 1.0.
"""
super().__init__(graph, free_memory, True, cnode, optim_multiplier)
self.memory_slots = memory_slots
# construct chain
unit = self.free_memory // self.memory_slots
self.chain = self._construct_chain(self.graph, self.node_list)
self.chain.discretize_all(unit)
self.cost_table = None
self.back_ptr = None
self.sequence = None
def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
"""Solve the checkpointing problem using rotor algorithm.
Args:
force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
verbose (bool, optional): Print verbose information. Defaults to False.
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
chain = self.chain
# compute cost table
if force_python:
self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots)
else:
self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)
if verbose:
self.print_chain()
# backtrack
try:
self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
self.back_ptr)
self._annotate_from_sequence(self.sequence, self.node_list)
except ValueError as e:
# use the logger to announce that the solver failed
logger = get_dist_logger()
logger.warning(f'Checkpoint solver failed: {e}')
raise
if verbose:
self.print_sequence()
return deepcopy(self.graph)
def print_chain(self):
print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1):
print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
self.chain.btmp[idx])
print(f'Chain = {self.chain}')
def print_sequence(self):
print(f'Sequence = {self.sequence}')
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
input_tensors = cls._extract_input(graph)
ftime, btime, ftmp, btmp = list(), list(), list(), list()
xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]
for node in node_list:
node_info = cls._extract_node_info(node)
ftime.append(node_info[0])
btime.append(node_info[1])
x.append(node_info[2])
xbar.append(node_info[3])
ftmp.append(node_info[4])
btmp.append(node_info[5])
# currently we view loss backward temp as zero
btime.append(0)
btmp.append(0)
return Chain(ftime, btime, x, xbar, ftmp, btmp)
@classmethod
def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
"""Extract node info from a list of nodes"""
xbar = 0
ftime = 0
btime = 0
fwd_mem_peak = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
# in this case we need to calculate memory usage directly based on the statistics hooked in node.meta
xbar += n.meta['fwd_mem_out']
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
btime += max(calculate_bwd_time(n), 1.0)
x = calculate_fwd_out(node[-1])
xbar = max(x, xbar)
ftmp = fwd_mem_peak - xbar
btmp = cls._extract_btmp(node)
return ftime, btime, x, xbar, ftmp, btmp
@staticmethod
def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
"""Extract input tensors from a Graph"""
input_tensors = []
for node in graph.nodes:
if node.op == 'placeholder':
input_tensors.append(node.meta['fwd_out'])
return input_tensors
@staticmethod
def _extract_unused_output(node: Node) -> int:
"""Extract unused output from `torch.fx.Node`"""
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
"""Extract btmp from a list of nodes"""
def _extract_deps_size():
deps_size = 0
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
btmp = 0
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
deps[child] = float('-inf') # free
return btmp
@staticmethod
def _compute_table(chain: Chain, mmax: int) -> Tuple:
"""Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
mmax (int): Maximum number of memory slots.
Returns:
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
and rhs = lhs...chain.length (both inclusive) and m = 0...mmax
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
is a chain checkpoint, or (False, j) if the optimal choice is a leaf
checkpoint at index j
"""
ftime = chain.ftime + [0.0]
btime = chain.btime
x = chain.x + [0]
xbar = chain.xbar + [0]
ftmp = chain.ftmp + [0]
btmp = chain.btmp + [0]
# Build table
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for m in range(mmax + 1):
for i in range(len(chain) + 1):
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
if m >= limit: # Equation (1)
cost_table[m][i][i] = ftime[i] + btime[i]
else:
cost_table[m][i][i] = float("inf")
# Compute everything
for m in range(mmax + 1):
for d in range(1, len(chain) + 1):
for i in range(len(chain) + 1 - d):
idx = i + d
mmin = x[idx + 1] + x[i + 1] + ftmp[i]
if idx > i + 1:
mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx)))
if m < mmin:
cost_table[m][i][idx] = float("inf")
else:
leaf_checkpoints = [(j,
sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
for j in range(i + 1, idx + 1)
if m >= x[j]]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
best_leaf = None
if m >= xbar[i + 1]:
chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx]
else:
chain_checkpoint = float("inf")
if best_leaf and best_leaf[1] <= chain_checkpoint:
cost_table[m][i][idx] = best_leaf[1]
back_ptr[m][i][idx] = (False, best_leaf[0])
else:
cost_table[m][i][idx] = chain_checkpoint
back_ptr[m][i][idx] = (True,)
return cost_table, back_ptr
@staticmethod
def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
try:
from .rotorc import compute_table
# build module if module not found
except ModuleNotFoundError:
import os
import subprocess
import sys
logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
f"--build-lib={this_dir}"
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if result.wait() == 0:
logger.info("rotorc has been built!", ranks=[0])
from .rotorc import compute_table
else:
logger.warning("rotorc built failed! Using python version!", ranks=[0])
return CheckpointSolverRotor._compute_table(chain, mmax)
return compute_table(chain, mmax)
@staticmethod
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
back_ptr: List[Any]) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
lhs (int): The left index of the interval to backtrack.
rhs (int): The right index of the interval to backtrack.
budget (int): The memory budget for processing this interval.
cost_table (List[Any]): See ``._compute_table()`` for definitions
back_ptr (List[Any]): See ``._compute_table()`` for definitions
Raises:
ValueError: Can not process the chain.
Returns:
sequence (Sequence): The sequence of executing nodes with checkpoints.
"""
if budget <= 0:
raise ValueError(f"Can not process a chain with non-positive memory {budget}")
elif cost_table[budget][lhs][rhs] == float("inf"):
raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")
sequence = Sequence()
if rhs == lhs:
if lhs == len(chain):
sequence += [Loss()]
else:
sequence += [ForwardEnable(lhs), Backward(lhs)]
return sequence
if back_ptr[budget][lhs][rhs][0]:
sequence += [
ForwardEnable(lhs),
CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
back_ptr),
Backward(lhs),
]
else:
best_leaf = back_ptr[budget][lhs][rhs][1]
sequence += [ForwardCheck(lhs)]
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence += [
CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
back_ptr),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
]
return sequence
@staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
"""Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.
Args:
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
node_list (List[List[Node]]): The list of nodes to annotate.
"""
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
# forward annotation
for idx, op in enumerate(fwd_list, 0):
if in_ckpt:
if isinstance(op, ForwardNograd):
ckpt_region.append(idx)
elif isinstance(op, ForwardEnable):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
else:
if isinstance(op, ForwardCheck):
in_ckpt = True
ckpt_region.append(idx)
# annotate the backward if there is any nested activation checkpoint
in_recompute = False
for op in bwd_list:
if in_recompute:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
in_recompute = False
else:
if not isinstance(op, Backward):
in_recompute = True
ckpt_idx = 0
ckpt_region = []
if isinstance(op, ForwardCheck):
ckpt_region.append(op.index)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list = []
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
nested_length = max(
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
for idx in range(start_idx, end_idx + 1):
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
len(op_list[idx].meta['activation_checkpoint']))
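
Restated in LaTeX, the recurrence implemented by ``_compute_table`` (and mirrored by the C extension) is, for a memory budget m and a chain interval [i, l]:

C_m(i, i) =
\begin{cases}
f_i + b_i & \text{if } m \ge x_{i+1} + \bar{x}_{i+1} + \max(f^{\mathrm{tmp}}_i,\ b^{\mathrm{tmp}}_i) \\
\infty & \text{otherwise}
\end{cases}

C_m(i, l) = \min\left(
\min_{\substack{i < j \le l \\ m \ge x_j}}
\left[ \sum_{k=i}^{j-1} f_k + C_{m - x_j}(j, l) + C_m(i, j - 1) \right],\;
C_m(i, i) + C_{m - \bar{x}_{i+1}}(i + 1, l)
\right)

where the second operand (the chain checkpoint) is admissible only when m \ge \bar{x}_{i+1}; ``back_ptr`` records which operand won, which is exactly what ``_backtrack`` consumes.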

View File

@@ -0,0 +1,184 @@
import math
from abc import ABC
from typing import Any, Iterable, List
from torch.utils._pytree import tree_map
class Chain:
def __init__(self,
ftime: List[float],
btime: List[float],
x: List[int],
xbar: List[int],
ftmp: List[int],
btmp: List[int],
check_consistency: bool = True):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details.
Args:
ftime (List[float]): The forward time of each node.
btime (List[float]): The backward time of each node.
x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper.
xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper.
ftmp (List[int]): The temporary forward memory of each node.
btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget.
check_consistency (bool, optional): Check the length consistency of the `Chain`. Defaults to True.
"""
self.ftime = ftime
self.btime = btime
self.x = x
self.xbar = xbar
self.ftmp = ftmp
self.btmp = btmp
if check_consistency and not self.check_lengths():
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
and (len(self.xbar) == len(self) + 1))
def __repr__(self):
chain_list = []
for i in range(len(self)):
chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i]))
i = len(self)
chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i]))
return chain_list.__repr__()
def __len__(self):
return len(self.ftime)
def discretize_all(self, unit: int):
"""Discretize the chain into a list of chains according to unit size."""
discretizer = lambda val: math.ceil(val / unit)
self.x = tree_map(discretizer, self.x)
self.xbar = tree_map(discretizer, self.xbar)
self.ftmp = tree_map(discretizer, self.ftmp)
self.btmp = tree_map(discretizer, self.btmp)
class Operation(ABC):
name = "Op"
def __repr__(self) -> str:
return f"{self.name}_{self.index}"
def shift(self, value):
if type(self.index) is tuple:
self.index = tuple(x + value for x in self.index)
else:
self.index += value
class Forward(Operation):
name = "F"
def __init__(self, index):
self.index = index
def cost(self, chain: Chain):
if chain is not None:
return chain.ftime[self.index]
else:
return 1
class ForwardEnable(Forward):
name = "Fe"
class ForwardNograd(Forward):
name = "Fn"
class ForwardCheck(Forward):
name = "CF"
class Forwards(Operation):
def __init__(self, start, end):
self.index = (start, end)
def __repr__(self):
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
def cost(self, chain: Chain):
if chain is not None:
return sum(chain.ftime[self.index[0]:self.index[1] + 1])
else:
return (self.index[1] - self.index[0] + 1)
def isForward(op):
return type(op) is Forward or type(op) is Forwards
class Backward(Operation):
name = "B"
def __init__(self, index):
self.index = index
def cost(self, chain: Chain):
if chain is not None:
return chain.btime[self.index]
else:
return 1
class Loss(Operation):
def __init__(self):
pass
def __repr__(self):
return "L"
def cost(self, chain):
return 0
class MemoryAccess(Operation):
name = "MA"
def __init__(self, index):
self.index = index
def cost(self, chain: Chain):
return 0
class WriteMemory(MemoryAccess):
name = "WM"
class ReadMemory(MemoryAccess):
name = "RM"
class DiscardMemory(MemoryAccess):
name = "DM"
class Sequence(list):
def __init__(self):
super().__init__()
def __repr__(self):
return repr(self.list_operations())
def list_operations(self):
op_list = []
for x in self:
if isinstance(x, Operation):
op_list.append(x)
else:
assert isinstance(x, Sequence)
op_list += x.list_operations()
return op_list
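
A tiny hand-built ``Chain`` illustrates the length contract enforced by ``check_lengths``: ``ftime``/``ftmp`` hold one entry per node, while ``btime``/``x``/``xbar``/``btmp`` hold one extra slot for the loss. The numbers below are made up:

# length-2 chain; all memory values are in bytes before discretization
chain = Chain(
    ftime=[1.0, 2.0],
    btime=[2.0, 4.0, 0.0],
    x=[512, 256, 256],
    xbar=[512, 768, 512],
    ftmp=[0, 0],
    btmp=[0, 0, 0],
)
chain.discretize_all(unit=128)  # every memory entry becomes ceil(v / 128)
print(chain.x)                  # [4, 2, 2]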

View File

@@ -0,0 +1,3 @@
from .meta_registry import *
from .metainfo import *
from .registry import meta_register

View File

@@ -0,0 +1,15 @@
import operator
import torch
import torch.nn as nn
from ..tensor_shard.constants import *
# list of inplace module
INPLACE_MODULE = [nn.ReLU]
# list of inplace operations
INPLACE_OPS = [torch.flatten]
# list of operations that do not save forward activations
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]

View File

@@ -0,0 +1,6 @@
from .activation import *
from .binary_elementwise_ops import *
from .conv import *
from .linear import *
from .norm import *
from .pooling import *

View File

@@ -0,0 +1,74 @@
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
__all__ = ["relu_meta_info"]
@meta_register.register(torch.nn.ReLU)
def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.ReLU metainfo generator
The aten graph of torch.nn.ReLU is
graph():
%input_2 : [#users=1] = placeholder[target=placeholder](default=)
%relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {})
%threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_tensor = args[0].data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
is_inplace = kwargs.get("inplace", False)
# construct input args for forward
fwd_in_args = [input_tensor]
# construct input args for backward
bwd_in_args = [output_tensor]
# calculate cost
# the fwd op with compute cost is relu.default
# the bwd op with compute cost is threshold_backward
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# NOTE: the inplace ReLU doesn't have forward memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(
activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]),
parameter=0,
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
# NOTE: It might seem a little weird here; we just want to align it with the older version
# of MetaInfoProp. In the future we might modify this part to make it clearer.
fwd_in = []
fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
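
Registering additional generators follows the same pattern. A hypothetical entry for ``torch.nn.Identity`` (not part of this diff), which has no compute cost and saves nothing for backward, might look like:

@meta_register.register(torch.nn.Identity)
def identity_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    # Identity returns its input unchanged, so only the output tensor matters
    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data

    compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

    fwd_mem = MemoryCost(activation=activation_size(output_tensor), parameter=0, temp=0, buffer=0)
    bwd_mem = MemoryCost(activation=0, parameter=0, temp=0, buffer=0)
    memory_cost = TrainCycleItem(fwd=fwd_mem, bwd=bwd_mem,
                                 total=MemoryCost(activation=fwd_mem.activation, parameter=0))

    fwd_in, fwd_buffer = [], []
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out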

View File

@@ -0,0 +1,66 @@
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..registry import meta_register
__all__ = ['binary_elementwise_meta_info']
@meta_register.register(BCAST_FUNC_OP)
def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Meta information generator for binary elementwise operations
NOTE: Some binary elementwise operations discard their input activations after computation, as they
don't need those tensors for back propagation. For example, the two tensors passed to `torch.add`
are discarded right after the add operation is done. We create a simple API in the `MetaInfo` class
to identify this behavior; it is critical for better memory estimation.
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
# construct forward args for flop mapping
fwd_in_args = [opdata.data for opdata in input_op_data]
fwd_out_args = [output_op_data.data]
# calculate cost
# calculate compute cost
# NOTE: we set bwd_compute_cost to twice fwd_compute_cost in this case
fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
fwd_mem_cost = MemoryCost(
activation=activation_size(output_op_data.data),
parameter=param_mem_cost,
)
bwd_mem_cost = MemoryCost(
activation=activation_size(fwd_in_args),
parameter=param_mem_cost,
)
# total cost
total_mem_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
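
The ``activation_size`` helper used throughout these generators sums ``numel() * element_size()`` over a tensor or a collection of tensors, as the calls above with both single tensors and lists suggest. A small illustration (fp32, so 4 bytes per element):

import torch

from colossalai.fx.profiler.memory_utils import activation_size

a = torch.zeros(4, 8, device='meta')  # 32 elements -> 128 bytes
b = torch.zeros(16, device='meta')    # 16 elements ->  64 bytes
assert activation_size(a) == 128
assert activation_size([a, b]) == 192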

View File

@@ -0,0 +1,137 @@
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['convnd_meta_info']
@meta_register.register(torch.nn.Conv1d)
@meta_register.register(torch.nn.Conv2d)
@meta_register.register(torch.nn.Conv3d)
@meta_register.register(torch.nn.functional.conv1d)
@meta_register.register(torch.nn.functional.conv2d)
@meta_register.register(torch.nn.functional.conv3d)
def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator
The aten graph of torch.nn.Convnd with bias is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%convolution_backward_default : [#users=3] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None, [None, None, None]), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
%detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {})
The aten graph of torch.nn.Convnd without bias is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None], [None, None], [None, None], None, [None, None], None), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%convolution_backward_default : [#users=2] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None], [None, None], [None, None], None, [None, None], None, [None, None, None]), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
has_bias: bool = False
input_tensor = args[0].data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
if len(args) == 4:
weight_tensors = [args[1].data, args[3].data]
else:
weight_tensors = [args[1].data]
# check if conv has bias
if len(weight_tensors) > 1:
has_bias = True
# bias tensor's shape only has one dimension
if len(weight_tensors[0].shape) == 1:
bias_tensor, weight_tensor = weight_tensors
else:
weight_tensor, bias_tensor = weight_tensors
else:
weight_tensor = weight_tensors[0]
# construct input args for forward
fwd_args = [None] * 9
# weight and input
fwd_args[0] = input_tensor
fwd_args[1] = weight_tensor
fwd_args[2] = bias_tensor if has_bias else None
# transpose indicator should be set to False
fwd_args[6] = False
# construct input args for backward
bwd_args = [None] * 11
# weight and input
bwd_args[0] = output_tensor
bwd_args[1] = input_tensor
bwd_args[2] = weight_tensor
bwd_args[-1] = [True, True, True] if has_bias else [True, True, False]
# calculate cost
# the fwd op with compute cost is convolution.default
# the bwd op with compute cost is convolution_backward.default
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# TODO: use profiler to check conv temp memory
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(
activation=activation_size([input_tensor, output_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(
activation=activation_size([input_tensor, weight_tensor, bias_tensor])
if has_bias else activation_size([input_tensor, weight_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
temp=0,
buffer=0)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@@ -0,0 +1,172 @@
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['linear_meta_info']
@meta_register.register(torch.nn.functional.linear)
@meta_register.register(torch.nn.Linear)
def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.Linear & torch.nn.functional.linear meta info generator
NOTE: currently we separate the bias part from biased linear ops; the memory consumption is considered
in the add metainfo generator, but we keep the bias mechanism in the linear metainfo generator for future use.
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
%zeros_like_default : [#users=3] = call_function[target=torch.ops.aten.zeros_like.default](args = (%addmm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
%mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
%sum_dim_int_list : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%zeros_like_default, [None], None), kwargs = {})
%view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_dim_int_list, [None]), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%view_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
%detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
%detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {})
The one without bias is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%input_2, None), kwargs = {})
%zeros_like_default : [#users=2] = call_function[target=torch.ops.aten.zeros_like.default](args = (%mm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
%mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
%mm_default_2 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default_2,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
has_bias: bool = False
input_tensor = args[0].data
output_tensor = args[2].data
if len(args) == 4:
weight_tensors = [args[1].data, args[3].data]
else:
weight_tensors = [args[1].data]
# process the dimension of input and output
if len(input_tensor.shape) > 2:
input_tensor: torch.Tensor
input_tensor = input_tensor.view(-1, input_tensor.shape[-1])
if len(output_tensor.shape) > 2:
output_tensor: torch.Tensor
output_tensor = output_tensor.view(-1, output_tensor.shape[-1])
if len(weight_tensors) > 1:
has_bias = True
if len(weight_tensors[0].shape) == 2:
weight_tensor, bias_tensor = weight_tensors
else:
bias_tensor, weight_tensor = weight_tensors
else:
weight_tensor = weight_tensors[0]
if has_bias:
# calculate cost with bias
# the fwd op with compute cost is addmm
# the bwd op with compute cost is mm * 2 and sum.dim_IntList
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# NOTE: Linear doesn't have buffer or temp memory in the forward and backward phases
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor; parameter cost is the size of weight_tensor and bias_tensor
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
# total cost is to sum the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
else:
# calculate cost without bias
# the fwd op with compute cost is mm
# the bwd op with compute cost is mm * 2
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# NOTE: Linear doesn't have buffer or temp memory in the forward and backward phases
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
parameter=activation_size(weight_tensor),
temp=0,
buffer=0)
# the backward activation cost is the size of input_tensor and weight_tensor; parameter cost is the size of weight_tensor
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]),
parameter=activation_size(weight_tensor),
temp=0,
buffer=0)
# total cost is to sum the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
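
As a rough sanity check of the compute-cost arithmetic above, here is a worked example; the MAC-style counting convention (``m * k * n`` for an ``(m, k) @ (k, n)`` matmul) is an assumption about ``flop_mapping``, not something stated in this diff:

# hypothetical shapes: batch 32, in_features 128, out_features 64
B, K, N = 32, 128, 64
fwd = B * K * N              # addmm / mm: (B, K) @ (K, N)
bwd = B * N * K + N * B * K  # grad_input mm + grad_weight mm
# with bias, sum.dim_IntList adds a reduction over the batch for grad_bias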

View File

@@ -0,0 +1,103 @@
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['batchnormnd_meta_info']
@meta_register.register(torch.nn.BatchNorm1d)
@meta_register.register(torch.nn.BatchNorm2d)
@meta_register.register(torch.nn.BatchNorm3d)
def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""BatchNorm1d, BatchNorm2d, BatchNorm3d, meta info generator
The aten graph of BatchNorm2d is like
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%cudnn_batch_norm_default : [#users=4] = call_function[target=torch.ops.aten.cudnn_batch_norm.default](args = (%input_2, None, None, None, None, None, None, None), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%cudnn_batch_norm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {})
%detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {})
%cudnn_batch_norm_backward_default : [#users=3] = call_function[target=torch.ops.aten.cudnn_batch_norm_backward.default](args = (%detach_default, %zeros_like_default, None, None, None, %detach_default_1, %detach_default_2, None, %detach_default_3), kwargs = {})
%detach_default_4 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {})
%detach_default_5 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_4,), kwargs = {})
%detach_default_6 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {})
%detach_default_7 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_6,), kwargs = {})
%detach_default_8 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {})
%detach_default_9 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_8,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_tensor = args[0].data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
mean_tensor = next(filter(lambda x: x.name == "running_mean", args)).data
var_tensor = next(filter(lambda x: x.name == "running_var", args)).data
num_batch = next(filter(lambda x: x.name == "num_batches_tracked", args)).data
# construct fwd args
# the fwd inputs are input, weight, bias, running_mean, running_var and some other args
# indicating the status of the module
# the fwd outputs are output, saved mean, saved inv std and num batches tracked
fwd_in_args = [input_tensor, weight_tensor, bias_tensor, mean_tensor, var_tensor, True, 0.1, 1e-5]
fwd_out_args = [output_tensor, mean_tensor, var_tensor, num_batch]
# construct bwd args
# the bwd inputs are upstream grad, input, weight, running_mean, running_var, saved mean,
# saved inv std and some other args indicating the status of the module
# the bwd outputs are input grad, weight grad and bias grad
bwd_in_args = [
output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
]
bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
# calculate cost
fwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm.default](fwd_in_args, fwd_out_args)
bwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm_backward.default](bwd_in_args, bwd_out_args)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# the fwd activation cost is the size of input and output plus saved mean and saved inv std
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
temp=0,
buffer=activation_size([mean_tensor, var_tensor]))
# the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
# and saved inv std during backward phase
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
temp=activation_size([mean_tensor, var_tensor]),
buffer=activation_size([mean_tensor, var_tensor]))
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
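To make the buffer/temp split above concrete, here is a hedged sketch for a hypothetical BatchNorm2d(64) on a (32, 64, 56, 56) fp32 input: saved mean and saved inv std are counted as buffer during the forward and re-appear as temp in the backward, where the kernel discards them.

import torch

bytes_of = lambda *ts: sum(t.numel() * t.element_size() for t in ts)
feature = torch.empty(32, 64, 56, 56, device='meta')
saved_mean = torch.empty(64, device='meta')
saved_inv_std = torch.empty(64, device='meta')
fwd_buffer_bytes = bytes_of(saved_mean, saved_inv_std)  # 2 * 64 * 4 = 512 bytes
bwd_temp_bytes = fwd_buffer_bytes                       # freed inside the backward kernel
print(bytes_of(feature), fwd_buffer_bytes, bwd_temp_bytes)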

View File

@ -0,0 +1,134 @@
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
__all__ = ["avgpool_meta_info", "maxpool_meta_info"]
@meta_register.register(torch.nn.AdaptiveAvgPool1d)
@meta_register.register(torch.nn.AdaptiveAvgPool2d)
@meta_register.register(torch.nn.AdaptiveAvgPool3d)
@meta_register.register(torch.flatten)
def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Meta info for AdaptiveAvgPool
The aten graph of AdaptiveAvgPool is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%_adaptive_avg_pool2d_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d.default](args = (%input_2, [None, None]), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%_adaptive_avg_pool2d_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%_adaptive_avg_pool2d_backward_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d_backward.default](args = (%zeros_like_default, %detach_default), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_adaptive_avg_pool2d_backward_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_tensor = args[0].data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
is_inplace = kwargs.get("inplace", False)
# construct forward args for flop mapping
fwd_in_args = [input_tensor]
fwd_out_args = [output_tensor]
# construct backward args for flop mapping
bwd_in_args = [output_tensor]
bwd_out_args = [input_tensor]
# calculate cost
# the fwd op with compute cost is _adaptive_avg_pool2d.default
# the bwd op with compute cost is _adaptive_avg_pool2d_backward.default
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
bwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d_backward.default](bwd_in_args, bwd_out_args)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor))
bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor))
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
@meta_register.register(torch.nn.MaxPool1d)
@meta_register.register(torch.nn.MaxPool2d)
@meta_register.register(torch.nn.MaxPool3d)
def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Meta info for MaxPool
The aten graph of MaxPool is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%max_pool2d_with_indices_default : [#users=2] = call_function[target=torch.ops.aten.max_pool2d_with_indices.default](args = (%input_2, [None, None], [None, None]), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%max_pool2d_with_indices_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_default,), kwargs = {})
%max_pool2d_with_indices_backward_default : [#users=1] = call_function[target=torch.ops.aten.max_pool2d_with_indices_backward.default](args = (%zeros_like_default, %detach_default, [None, None], [None, None], [None, None], [None, None], None, %detach_default_1), kwargs = {})
%detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_backward_default,), kwargs = {})
%detach_default_3 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_2,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
# construct forward args for flop mapping
fwd_in_args = [input_tensor]
fwd_out_args = [output_tensor]
# construct backward args for flop mapping
bwd_in_args = [output_tensor]
bwd_out_args = [input_tensor]
# construct index matrix
index_matrix = torch.zeros_like(output_tensor, device="meta", dtype=torch.int64)
# calculate cost
# the fwd op with compute cost is max_pool2d_with_indices.default
# the bwd op with compute cost is max_pool2d_with_indices_backward.default
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices.default](fwd_in_args, fwd_out_args)
bwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices_backward.default](bwd_in_args, bwd_out_args)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# NOTE: the index matrix will be discarded in backward phase
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix]))
# temp memory for backward is the index matrix to be discarded
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix),
temp=activation_size(index_matrix))
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
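A short check of the index-matrix accounting above: for the same shape, an int64 index matrix costs twice the bytes of an fp32 output, is saved as a forward buffer, and shows up only as temp memory in the backward because it is discarded there. The shape below is arbitrary.

import torch

out = torch.empty(8, 16, 14, 14, device='meta')
index_matrix = torch.zeros_like(out, device='meta', dtype=torch.int64)
out_bytes = out.numel() * out.element_size()                      # 4 bytes per fp32 element
index_bytes = index_matrix.numel() * index_matrix.element_size()  # 8 bytes per int64 element
assert index_bytes == 2 * out_bytes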

View File

@ -0,0 +1,117 @@
from typing import Callable, List
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['MetaInfo']
class MetaInfo:
"""MetaInfo class
This class is used to store meta info based on sharding strategy and the given
target function.
"""
def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None:
# compute cost of forward and backward computation
self.compute_cost: TrainCycleItem
# compute memory cost of forward and backward phase
self.memory_cost: TrainCycleItem
# list of input tensors
self.fwd_in: List[torch.Tensor]
# list of buffer tensors
self.fwd_buffer: List[torch.Tensor]
# list of output tensors
self.fwd_out: List[torch.Tensor]
# sharding strategy
self._strategy = strategy
# target function
self._target = target
# compute metainfo if possible
if self._strategy is not None and self._target is not None:
self.compute_metainfo()
@property
def strategy(self) -> ShardingStrategy:
return self._strategy
@property
def target(self) -> Callable:
return self._target
@strategy.setter
def strategy(self, strategy: ShardingStrategy) -> None:
self._strategy = strategy
if self._strategy is not None and self._target is not None:
self.compute_metainfo()
@target.setter
def target(self, target: Callable) -> None:
self._target = target
if self._strategy is not None and self._target is not None:
self.compute_metainfo()
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
"""
Compute sharded opdata based on the given data and sharding spec.
"""
return OperationData(name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
def compute_metainfo(self):
"""
Compute meta info based on sharding strategy and the given target function.
"""
assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__):
# module
meta_func = meta_register.get(self._target.__class__)
# check whether the target is in the list of ops whose activation we do not need to save
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
else:
# function
meta_func = meta_register.get(self._target)
# check whether the target is in the list of ops whose activation we do not need to save
save_fwd_in = self._target not in NO_SAVE_ACTIVATION
# construct args for meta_func
args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]
# construct kwargs
if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace}
elif self.target in INPLACE_OPS:
kwargs = {'inplace': True}
else:
kwargs = {'inplace': False}
# compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
# process corner case for NO_SAVE_ACTIVATION
if not save_fwd_in:
self.fwd_in = []
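The setter pattern used by MetaInfo, shown in isolation: recompute the derived state as soon as both dependencies are present. This is a generic sketch of the design choice, not the MetaInfo API.

class Lazy:

    def __init__(self, a=None, b=None):
        self._a, self._b, self.result = a, b, None
        self._maybe_compute()

    def _maybe_compute(self):
        # mirrors compute_metainfo() firing once strategy and target are both set
        if self._a is not None and self._b is not None:
            self.result = self._a + self._b

    @property
    def a(self):
        return self._a

    @a.setter
    def a(self, value):
        self._a = value
        self._maybe_compute()

obj = Lazy(b=1)
obj.a = 2    # the setter triggers the recompute, like MetaInfo.strategy / .target
assert obj.result == 3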

View File

@ -0,0 +1,32 @@
__all__ = ['Registry']
class Registry:
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
def wrapper(func):
if isinstance(source, (list, tuple)):
# support registering a list of items for this func
for element in source:
self.store[element] = func
else:
self.store[source] = func
return func
return wrapper
def get(self, source):
assert source in self.store, f'{source} not found in the {self.name} registry'
target = self.store[source]
return target
def has(self, source):
return source in self.store
meta_register = Registry('meta')
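A hypothetical usage sketch of the Registry above; registering one handler for several keys at once has the same shape as the stacked BatchNorm1d/2d/3d decorators earlier in this diff. The names here are made up for illustration.

demo_register = Registry('demo')

@demo_register.register([int, float])
def double(x):
    return x * 2

assert demo_register.has(int) and demo_register.has(float)
assert demo_register.get(float)(3.5) == 7.0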

View File

@ -0,0 +1,113 @@
from typing import Dict
import torch
from torch.fx import GraphModule
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
target_sharding_spec: ShardingSpec) -> MetaInfo:
# get comm_action_sequence and total_cost from shape_consistency_manager
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
origin_sharding_spec, target_sharding_spec)
meta_info = MetaInfo()
# NOTE: the cost in shape_consistency_manager.mem_cost is counted in number of elements (numel)
# get mem cost for MetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
# extract user that has _meta_data and extract element length
input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
element_length = input_node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
mem_cost.fwd.temp *= element_length
mem_cost.bwd.activation *= element_length
mem_cost.bwd.temp *= element_length
mem_cost.total.activation *= element_length
meta_info.memory_cost = mem_cost
# get computation cost for MetaInfo
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
# get tensor shape for MetaInfo
origin_sharding_spec: ShardingSpec
target_sharding_spec: ShardingSpec
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
output_shape = target_sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
return meta_info
def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
"""
This method is used to construct `MetaInfo` for the shape consistency node
"""
# extract node index and user node index
args = node.args
node_index, user_node_index = args[3], args[4]
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
user_node_index]
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
# extract node_index and op_data_name
node_index, op_data_name = node.args[2], node.args[3]
comm_action = comm_actions_dict[node_index][op_data_name]
if isinstance(comm_action.comm_spec, CommSpec):
# this case is for all_reduce; there will be no memory cost
meta_info = MetaInfo()
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
element_length = output_node._meta_data.element_size()
total_cost = comm_action.comm_spec.get_comm_cost()
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
else:
# this case will be handled by shape consistency manager
origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
'tgt_spec']
meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
return meta_info
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
comm_actions_dict: Dict) -> GraphModule:
"""
The method manages all the metainfo of the communication nodes (runtime_apply, runtime_comm_spec_apply) in the graph.
"""
for node in gm.graph.nodes:
if node.target == runtime_apply:
setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply:
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else:
pass
return gm
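The numel-to-bytes conversion above, in isolation: shape consistency costs come back counted in elements, so every field is scaled by the element_size() of the dtype flowing through the node. A small sketch with fp32 and fp16:

import torch

numel_cost = 1024
for dtype in (torch.float32, torch.float16):
    element_length = torch.empty(0, dtype=dtype).element_size()
    print(dtype, numel_cost * element_length)   # 4096 bytes for fp32, 2048 for fp16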

View File

@ -0,0 +1,8 @@
import torch
OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten]
OUTPUT_SAVED_MOD = [
torch.nn.ReLU,
torch.nn.Softmax,
]

View File

@ -0,0 +1,165 @@
import uuid
from dataclasses import asdict
from typing import List
import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
def __init__(self, module: GraphModule) -> None:
self.module = module
self.func_dict = {
'placeholder': self.placeholder_handler,
'get_attr': self.get_attr_handler,
'output': self.output_handler,
'call_function': self.node_handler,
'call_module': self.node_handler,
'call_method': self.node_handler,
}
def _set_data_ptr(self, x):
"""
Set a uuid-based fake data pointer on the tensor
"""
if isinstance(x, torch.Tensor):
if not x.data_ptr():
data_ptr = uuid.uuid4()
x.data_ptr = lambda: data_ptr
def _is_inplace(self, node: Node):
"""
Check if the node is an inplace operation.
"""
if node.op == 'call_module':
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
return False
def run(self) -> GraphModule:
"""
Run the meta information propagation pass on the module.
"""
for node in self.module.graph.nodes:
node: Node
self.func_dict[node.op](node)
@compatibility(is_backward_compatible=False)
def placeholder_handler(self, node: Node) -> None:
"""
Handle the placeholder node.
"""
graph_info = GraphInfo()
out = _normalize_tuple(getattr(node, '_meta_data', None))
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
@compatibility(is_backward_compatible=False)
def get_attr_handler(self, node: Node) -> None:
"""
Handle the get_attr node.
"""
graph_info = GraphInfo()
node.meta = {**asdict(graph_info)}
@compatibility(is_backward_compatible=False)
def output_handler(self, node: Node) -> None:
"""
Handle the output node.
"""
graph_info = GraphInfo()
output_tensors = []
for par in node._input_nodes:
if par.meta:
output_tensors += par.meta["fwd_out"]
graph_info.fwd_in = output_tensors
node.meta = {**asdict(graph_info)}
@compatibility(is_backward_compatible=False)
def node_handler(self, node: Node) -> None:
"""
Handle the other kinds of nodes
"""
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_metainfo
meta_info: MetaInfo
# set data_ptr for input_tensor in MetaInfo class
input_tensors: List[torch.Tensor] = meta_info.fwd_in
buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
output_tensors: List[torch.Tensor] = meta_info.fwd_out
if self._is_inplace(node):
# inplace operation will not create new tensor, and it only has one parent node
# TODO: Verify this observation
# set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
parent_node = list(node._input_nodes.keys())[0]
parent_tensor = parent_node.meta.get("fwd_out")[0]
parent_tensor: torch.Tensor
for tensor in input_tensors:
tensor.data_ptr = parent_tensor.data_ptr
for tensor in buffer_tensors:
tensor.data_ptr = parent_tensor.data_ptr
for tensor in output_tensors:
tensor.data_ptr = parent_tensor.data_ptr
else:
for par in node._input_nodes:
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor
target_input_tensor = next(
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr
# set data_ptr for tensors in input_tensors that have not been set yet
for tensor in input_tensors:
if not tensor.data_ptr():
self._set_data_ptr(tensor)
# set data_ptr for buffer_tensor
for tensor in buffer_tensors:
self._set_data_ptr(tensor)
# set data_ptr for output_tensor
for tensor in output_tensors:
self._set_data_ptr(tensor)
# attach them to graph_info
graph_info.fwd_in = input_tensors
graph_info.fwd_tmp = buffer_tensors
graph_info.fwd_out = output_tensors
# fetch other memory information
memory_cost = meta_info.memory_cost
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
graph_info.fwd_mem_out = memory_cost.fwd.activation
graph_info.bwd_mem_tmp = memory_cost.bwd.temp
graph_info.bwd_mem_out = memory_cost.bwd.activation
# fetch flop information
# here we use fwd_time and bwd_time to deal with the case where the
# communication cost is a float
compute_cost = meta_info.compute_cost
graph_info.fwd_time = compute_cost.fwd
graph_info.bwd_time = compute_cost.bwd
node.meta = {**asdict(graph_info)}
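The uuid trick used by _set_data_ptr, shown in isolation. It assumes a torch version where meta tensors report data_ptr() == 0, so a unique fake pointer can be attached per storage; tensors that share one pointer are then treated as aliases.

import uuid
import torch

def tag(x: torch.Tensor):
    if not x.data_ptr():            # meta tensors have a null data pointer
        ptr = uuid.uuid4()
        x.data_ptr = lambda: ptr    # shadow the method on this instance only

a = torch.empty(2, 2, device='meta')
b = torch.empty(2, 2, device='meta')
tag(a)
tag(b)
assert a.data_ptr() != b.data_ptr()  # distinct storages
b.data_ptr = a.data_ptr              # alias b to a, like the inplace branch above
assert a.data_ptr() == b.data_ptr()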

View File

@ -0,0 +1,221 @@
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int):
"""
This method will be invoked during runtime to perform shape consistency, which makes sure the activation is
converted into the form the user node expects.
"""
origin_sharding_spec = origin_dict[node_index]
target_sharding_spec = input_dict[node_index][user_node_index]
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
user_node_index: int):
"""
This method will be invoked during runtime to perform shape consistency, which makes sure activations of
tuple or list type are converted into the form the user node expects.
"""
rst = []
for index, (origin_sharding_spec,
target_sharding_spec) in enumerate(zip(origin_dict[node_index],
input_dict[node_index][user_node_index])):
rst.append(
shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
target_sharding_spec))
rst = type(node)(rst)
return rst
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
"""
This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
"""
comm_action = comm_actions_dict[node_index][op_data_name]
if isinstance(comm_action.comm_spec, CommSpec):
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
else:
origin_sharding_spec = comm_action.comm_spec['src_spec']
tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
return rst
def _preprocess_graph(nodes: List[Node]):
"""
This method is used to extract all the placeholders with sharding information,
and to map the nodes to their indices in the origin graph.
"""
# map each node to its index in the origin graph
node_to_index_dict = {}
index = 0
for node in nodes:
if node.target == 'sharding_spec_convert_dict':
input_dict_node = node
continue
if node.target == 'origin_node_sharding_spec_dict':
origin_dict_node = node
continue
if node.target == 'comm_actions_dict':
comm_actions_dict_node = node
continue
if not hasattr(node, 'best_strategy'):
continue
node_to_index_dict[node] = index
index += 1
return input_dict_node, origin_dict_node, comm_actions_dict_node, node_to_index_dict
def _shape_consistency_apply(gm: torch.fx.GraphModule):
"""
This pass is used to add the shape consistency node to the origin graph.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
if not hasattr(node, 'best_strategy') or node.op == 'output':
continue
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
if isinstance(node.sharding_spec, (list, tuple)):
assert isinstance(
node.target_sharding_specs,
(list,
tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
total_difference = 0
for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
node.target_sharding_specs[user_node_index]):
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
if total_difference == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply_for_iterable_object,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
else:
assert isinstance(node.sharding_spec,
ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional argument or key word argument of user node
if node in new_args:
# substitute the origin node with shape_consistency_node
origin_index_args = new_args.index(node)
new_args[origin_index_args] = shape_consistency_node
user_node.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with shape_consistency_node
new_kwargs[str(node)] = shape_consistency_node
user_node.kwargs = new_kwargs
return gm
def _comm_spec_apply(gm: torch.fx.GraphModule):
"""
This pass is used to add the comm spec apply node to the origin graph.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
if not hasattr(node, 'best_strategy') or node.op == 'output':
continue
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
if op_data.type == OperationDataType.OUTPUT:
comm_object = node
elif comm_action.key_for_kwarg is not None:
comm_object = node.kwargs[comm_action.key_for_kwarg]
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
# the origin node may be a positional argument or key word argument of user node
if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node
new_kwargs = dict(node.kwargs)
new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
node.kwargs = new_kwargs
else:
# substitute the origin node with comm_spec_apply_node
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = tuple(new_args)
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(node, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
user_list = list(node.users.keys())
for user in user_list:
if user == comm_spec_apply_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional argument or key word argument of user node
if node in new_args:
# substitute the origin node with comm_spec_apply_node
new_args[new_args.index(node)] = comm_spec_apply_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
return gm
def runtime_apply_pass(gm: torch.fx.GraphModule):
"""
The method manages all the passes acting on the distributed training runtime.
"""
gm = _shape_consistency_apply(gm)
gm = _comm_spec_apply(gm)
return gm
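A self-contained torch.fx sketch of the rewiring these passes perform: insert a conversion node before a user node, then substitute it into the user's args. `convert` and `M` are hypothetical stand-ins; `convert` plays the role of runtime_apply's shape consistency call.

import torch
from torch.fx import symbolic_trace

def convert(x):
    return x    # stand-in for a runtime shape-consistency conversion

class M(torch.nn.Module):

    def forward(self, x):
        return x + 1

gm = symbolic_trace(M())
for node in tuple(gm.graph.nodes):
    if node.op == 'placeholder':
        user = next(iter(node.users))
        with gm.graph.inserting_before(user):
            conv_node = gm.graph.create_node('call_function', convert, args=(node,))
        new_args = list(user.args)
        new_args[new_args.index(node)] = conv_node   # rewire the user onto the new node
        user.args = tuple(new_args)
gm.recompile()
assert gm(torch.zeros(1)).item() == 1.0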

View File

@ -0,0 +1,471 @@
import operator
from copy import deepcopy
from typing import Dict, List, Union
import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationDataType,
ShardingStrategy,
)
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
def size_processing(size: Union[int, torch.Size],
dim_partition_dict: Dict[int, List[int]],
device_mesh_info: Dict[int, int],
target_dim: int = None,
node_name: str = None):
"""
This method will be invoked during runtime to convert a size node's value according to the distributed information.
"""
if target_dim is not None:
assert isinstance(size, int)
if target_dim in dim_partition_dict:
total_shard_size = 1
for shard_dim in dim_partition_dict[target_dim]:
total_shard_size *= device_mesh_info[shard_dim]
size = size * total_shard_size
else:
size = list(size)
for dim, dim_size in enumerate(size):
if dim in dim_partition_dict:
total_shard_size = 1
for shard_dim in dim_partition_dict[dim]:
total_shard_size *= device_mesh_info[shard_dim]
size[dim] = dim_size * total_shard_size
size = torch.Size(size)
return size
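# A quick usage sketch of size_processing above, assuming a hypothetical (2, 4)
# device mesh with tensor dim 0 sharded on mesh axis 0 and dim 1 on mesh axis 1:
# every sharded dim of the local size is scaled back to its global extent.
assert size_processing(torch.Size([8, 16, 32]), {0: [0], 1: [1]}, {0: 2, 1: 4}) == torch.Size([16, 64, 32])
assert size_processing(8, {0: [0]}, {0: 2}, target_dim=0) == 16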
def _solution_annotatation(gm: torch.fx.GraphModule,
solution: List[int],
strategies_constructor: StrategiesConstructor = None):
"""
This method is used to stick the solver's chosen strategy onto each node and to add the information
required at runtime into the graph as placeholder nodes.
"""
mod_graph = gm.graph
# TODO: In future PR, strategies_constructor should be a required argument,
# instead of optional argument. This is because we don't need to consider nodes with
# no strategy in runtime preparation pass.
if strategies_constructor is not None:
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
no_strategy_nodes = strategies_constructor.no_strategy_nodes
else:
nodes = tuple(mod_graph.nodes)
no_strategy_nodes = []
# the dict to get origin sharding spec of node
origin_node_sharding_spec_dict = {}
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
# stick the solution strategy to the corresponding node
setattr(node, 'best_strategy', strategies_vector[strategy_index])
setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
str(node))
# attach the corresponding metainfo if node has the attribute `metainfo_vector`
if hasattr(node, 'metainfo_vector'):
setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index])
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
# the dict to record comm actions of nodes
comm_actions_dict = {}
for index, node in enumerate(nodes):
target_sharding_specs = []
for user_node in node.strategies_vector.successor_nodes:
if user_node in no_strategy_nodes:
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name))
else:
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
setattr(node, 'target_sharding_specs', target_sharding_specs)
# the get_attr node strategy is a kind of pending strategy, which means we will change it
# to match the strategy of its user node.
if node.op == 'get_attr':
assert len(target_sharding_specs) == 1, 'weight sharing is not supported in the current version.'
target_node = node.strategies_vector.successor_nodes[0]
node_name = str(node)
if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
node_name = str(target_node)
target_node = target_node.strategies_vector.successor_nodes[0]
user_strategy = target_node.best_strategy
op_data_in_user = user_strategy.get_op_data_by_name(node_name)
origin_pending_strategy = node.best_strategy
origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
new_communication_actions = {}
if op_data_in_user in user_strategy.communication_actions:
new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
new_communication_action.arg_index = 0
new_communication_actions[origin_op_data] = new_communication_action
node.best_strategy.communication_actions = new_communication_actions
comm_action_dict = {}
for op_data, comm_action in node.best_strategy.communication_actions.items():
comm_action_dict[op_data.name] = comm_action
comm_actions_dict[index] = comm_action_dict
# add above dicts into graph
for node in nodes:
if node.op != 'placeholder':
with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
break
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
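# The core of the annotation step above, reduced to a sketch with hypothetical
# stand-ins for the solver output: pair each node with its chosen strategy index
# and pin the winner onto the node for later passes to read.
class _Node:

    def __init__(self, strategies):
        self.strategies_vector = strategies

_nodes = [_Node(['RR', 'S0R', 'RS1']), _Node(['RR', 'S1R'])]
_solution = [2, 0]
for _node, _strategy_index in zip(_nodes, _solution):
    setattr(_node, 'best_strategy', _node.strategies_vector[_strategy_index])
assert _nodes[0].best_strategy == 'RS1' and _nodes[1].best_strategy == 'RR'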
def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
In the auto parallel system, tensors may be sharded across different devices, so size values need to be
converted back to the size of the original tensor before they are consumed by user ops such as torch.view,
torch.reshape, etc. These nodes carry enough information, like the input sharding_spec and
output sharding_spec, to decide how to convert the size value.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
node_pairs = {}
for node in nodes:
if node.op == 'call_method' and node.target == 'size':
# extract useful information from the size node
# dim_partition_dict indicates which dimensions of the size
# value should be enlarged.
sharding_spec = node.args[0].sharding_spec
dim_partition_dict = sharding_spec.dim_partition_dict
# there are two usages of torch.Tensor.size:
# tensor.size()
# tensor.size(dim)
# if a target_dim is assigned, then the output will be
# in type of int, instead of torch.Size
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
if target_dim < 0:
target_dim += node.args[0]._meta_data.dim()
# the DeviceMesh information determines the scaling factors for the size value
device_mesh_info = {}
for dim, dim_size in enumerate(device_mesh.mesh_shape):
device_mesh_info[dim] = dim_size
with mod_graph.inserting_after(node):
size_processing_node = mod_graph.create_node('call_function',
size_processing,
args=(node, dim_partition_dict, device_mesh_info,
target_dim, node.name))
# store the original node and processing node pair in the node_pairs dictionary
# it will be used to replace the original node with the processing node in slice objects
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
user_list = list(node.users.keys())
for user in user_list:
if user == size_processing_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional argument or key word argument of user node
if node in new_args:
# substitute the origin node with size_processing_node
new_args[new_args.index(node)] = size_processing_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with size_processing_node
new_kwargs[str(node)] = size_processing_node
user.kwargs = new_kwargs
if node.op == 'call_function' and node.target == operator.getitem:
getitem_index = node.args[1]
# slice objects are quite special in the torch.fx graph:
# on the one hand, we treat a slice object the same as an int,
# so we do not create a node for it. On the other hand, a slice
# object can take an fx.Node as its argument, and that user
# relationship cannot be tracked in the fx graph.
# Therefore, we record the node_pairs in this pass and use them
# to replace the original node argument inside the slice object
# if it has been processed in the pass above.
# There are three main usages of operator.getitem:
# getitem(input, int)
# getitem(input, slice)
# getitem(input, Tuple[slice])
# In this pass, we need to process the last two cases because
# node arguments may potentially appear in them.
if isinstance(getitem_index, slice):
new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
# check start, stop and step independently, since more than one of them
# may be a node that has been processed in the pass above
if getitem_index.start in node_pairs:
new_start = node_pairs[getitem_index.start]
if getitem_index.stop in node_pairs:
new_stop = node_pairs[getitem_index.stop]
if getitem_index.step in node_pairs:
new_step = node_pairs[getitem_index.step]
new_slice_item = slice(new_start, new_stop, new_step)
new_args = (node.args[0], new_slice_item)
node.args = new_args
elif isinstance(getitem_index, (tuple, list)):
if not isinstance(getitem_index[0], slice):
continue
new_slice_items = []
for slice_item in getitem_index:
if slice_item is None:
new_slice_items.append(None)
continue
new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
if slice_item.start in node_pairs:
new_start = node_pairs[slice_item.start]
if slice_item.stop in node_pairs:
new_stop = node_pairs[slice_item.stop]
if slice_item.step in node_pairs:
new_step = node_pairs[slice_item.step]
new_slice_item = slice(new_start, new_stop, new_step)
new_slice_items.append(new_slice_item)
new_args = (node.args[0], tuple(new_slice_items))
node.args = new_args
return gm
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
This pass will process node args to adapt them to the distributed tensor layout.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
for node in nodes:
# skip the placeholder nodes added in the _solution_annotatation pass
if not hasattr(node, 'sharding_spec'):
continue
def _process_sharding_spec(sharding_spec):
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
assert isinstance(sharding_spec,
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
for element in sharding_spec:
dim_partition_dict.append(_process_sharding_spec(element))
return dim_partition_dict, device_mesh
output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
new_args = []
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
# process the node with (input, *shape) style args
if method in (torch.Tensor.view, torch.Tensor.reshape):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, (int, tuple, list)):
new_args.append(arg._meta_data)
else:
new_args.append(arg)
else:
assert isinstance(
arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items():
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
# There are two ways to use torch.view:
# 1. torch.view(input, *shape)
# 2. torch.view(input, shape)
if isinstance(new_args[1], int):
# we will skip the dim with -1 value
if new_args[dim + 1] == -1:
continue
else:
new_args[dim + 1] //= total_shard_size
else:
new_args[1] = list(new_args[1])
# we will skip the dim with -1 value
if new_args[1][dim] == -1:
continue
else:
new_args[1][dim] //= total_shard_size
node.args = tuple(new_args)
elif node.op == 'call_function':
target = node.target
# process the node with (input, torch.Size) style args
if target in (torch.reshape,):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, (tuple, list)):
new_args.append(list(arg._meta_data))
else:
new_args.append(arg)
else:
assert isinstance(
arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
new_args.append(list(arg))
for dim, shard_dims in output_dim_partition_dict.items():
# we will skip the dim with -1 value
if new_args[1][dim] == -1:
continue
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
new_args[1][dim] //= total_shard_size
node.args = tuple(new_args)
return gm
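# The view/reshape rescaling above in miniature: a global target shape of
# [4, 16, -1] with dim 1 split across 2 devices becomes the local [4, 8, -1];
# -1 entries are skipped because torch infers them. All values are hypothetical.
_global_shape = [4, 16, -1]
_dim_partition = {1: [0]}        # tensor dim 1 sharded on mesh axis 0
_mesh_axis_size = {0: 2}
for _dim, _shard_dims in _dim_partition.items():
    _total_shard_size = 1
    for _shard_dim in _shard_dims:
        _total_shard_size *= _mesh_axis_size[_shard_dim]
    if _global_shape[_dim] != -1:
        _global_shape[_dim] //= _total_shard_size
assert _global_shape == [4, 8, -1]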
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
Apply the sharding action to the module parameters and buffers following the
instructions of the solver's solution.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
# This stream is created for overlapping communication and computation.
reduction_stream = torch.cuda.Stream()
for node in nodes:
if node.op == 'call_module':
target_module = node.graph.owning_module.get_submodule(node.target)
# TODO: we need to do more actions to take care of the shared parameters.
if hasattr(target_module, 'processed') and target_module.processed:
continue
setattr(target_module, 'processed', True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
param.data, param.sharding_spec, target_sharding_spec).detach().clone()
setattr(target_module, name, param)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec):
def hook_fn(grad):
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(param, comm_spec_to_use)
sharded_buffer_dict = {}
# apply the sharding spec of buffers
for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
setattr(buffer, 'sharding_spec', origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
sharded_buffer_dict[name] = buffer_sharded
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())
if node.op == 'get_attr':
root = node.graph.owning_module
atoms = node.target.split(".")
attr_len = len(atoms)
if attr_len == 1:
target_module = root
target = getattr(root, atoms[0])
else:
target_module = root
for atom in atoms[:-1]:
target_module = getattr(target_module, atom)
target = getattr(target_module, atoms[-1])
target_sharding_spec = node.sharding_spec
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
setattr(target, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
target.data, target.sharding_spec, target_sharding_spec).detach().clone()
assert hasattr(target_module, atoms[-1])
setattr(target_module, atoms[-1], target)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec):
def hook_fn(grad):
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(target, comm_spec_to_use)
return gm
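# The wrapper closure above, in isolation: without the extra function, every
# hook would close over the loop's final `param`/`comm_spec`. The all-reduce is
# replaced by a no-op here because no process group is initialized in a sketch.
def _attach_hook(param, comm_spec):

    def hook_fn(grad):
        # _all_reduce(grad, comm_spec, async_op=False) in the real pass
        return grad

    param.register_hook(hook_fn)

_p = torch.nn.Parameter(torch.ones(3))
_attach_hook(_p, comm_spec=None)
_p.sum().backward()
assert torch.equal(_p.grad, torch.ones(3))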
def implicit_comm_action_apply(gm: torch.fx.GraphModule):
"""
Replace the origin kernel with a kernel that performs implicit communication internally.
"""
pass
def runtime_preparation_pass(gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor = None):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
gm, solution, strategies_constructor)
gm = _size_value_converting(gm, device_mesh)
gm = _node_args_converting(gm, device_mesh)
# TODO: the pass below should be uncommented once the implementation of implicit_comm_action_apply is completed.
# gm = implicit_comm_action_apply(gm)
gm = _module_params_sharding(gm, device_mesh)
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict

View File

@ -1,6 +1,7 @@
import operator
import torch
__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
@ -25,7 +26,14 @@ ELEMENTWISE_METHOD_OP = [
# TODO: contiguous may need some extra processing.
torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_FUNC_OP = [
torch.flatten,
torch.reshape,
torch.transpose,
torch.split,
torch.permute,
operator.getitem,
]
RESHAPE_METHOD_OP = [
torch.Tensor.view,
torch.Tensor.unsqueeze,
@ -35,7 +43,7 @@ RESHAPE_METHOD_OP = [
]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,

View File

@ -1,6 +1,6 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .solver import Solver
from .strategies_constructor import StrategiesConstructor

View File

@ -5,10 +5,11 @@ from functools import reduce
from typing import Dict, List, Optional, Union
import torch
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INFINITY_COST
@ -17,7 +18,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
Args:
input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
@ -58,7 +59,7 @@ def generate_resharding_costs(nodes: List[Node],
nodes (List[Node]): a list of nodes
sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}

View File

@ -3,9 +3,9 @@ import warnings
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
@ -71,19 +71,19 @@ class ConvHandler(OperatorHandler):
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number of partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number of partitions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number of partitions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
@ -541,14 +541,14 @@ class ConvHandler(OperatorHandler):
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
conv_handler.register_strategy_into_strategies_vector()
for strategy in conv_handler.strategies_vector:
print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
Output:
S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}

View File

@ -6,9 +6,9 @@ from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from .operator_handler import OperatorHandler
@ -82,13 +82,13 @@ class MatVecStrategyGenerator(StrategyGenerator):
class MatMulStrategyGenerator(StrategyGenerator):
"""
MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
A matmul can be formulated as [n, p] x [p, q] = [n, q]
Args:
is_linear (bool): whether this generator is used for nn.Linear and F.linear.
This will incur extra transformation of the dim partitioning as the weight is transposed.
"""
@ -255,7 +255,7 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
"""
Generate sharding strategies for the batched matrix multiplication.
A batched matrix multiplication can be viewed as
[b, i, k] x [b, k, j] -> [b, i, j]
"""
@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -451,7 +451,7 @@ class DotHandler(OperatorHandler):
# create and register strategy
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -491,7 +491,7 @@ class DotHandler(OperatorHandler):
communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -529,7 +529,7 @@ class DotHandler(OperatorHandler):
communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -564,7 +564,7 @@ class DotHandler(OperatorHandler):
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -600,7 +600,7 @@ class DotHandler(OperatorHandler):
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -636,7 +636,7 @@ class DotHandler(OperatorHandler):
communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
communication_cost = communication_cost_weight_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -673,7 +673,7 @@ class DotHandler(OperatorHandler):
activation_memory_cost, 0)
communication_cost = communication_cost_forward_activation
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -709,7 +709,7 @@ class DotHandler(OperatorHandler):
input_grad_memory_cost, 0)
communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,

View File

@ -2,10 +2,14 @@ import operator
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size, ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size,
ignore_sharding_exception,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
@ -63,19 +67,19 @@ class LayerNormHandler(OperatorHandler):
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number of partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number of partitions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number of partitions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
@ -216,7 +220,7 @@ class LayerNormHandler(OperatorHandler):
norm_handler.register_strategy()
for strategy in norm_handler.strategies_vector:
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
Output:
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
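For readers following the memory-cost docstring above, a minimal sketch of the per-device accounting, assuming costs scale as numel / sharding_size times the element size (the function name and arguments here are hypothetical, not the handler's real signature):
def layernorm_memory_cost_sketch(numel_activation, numel_weight,
                                 sharding_size_forward,
                                 sharding_size_backward_activation,
                                 sharding_size_weight,
                                 bytes_per_elem=4):
    # forward activation is split into sharding_size_forward partitions
    memory_cost_forward = numel_activation / sharding_size_forward * bytes_per_elem
    # backward holds the activation grad plus the weight grad, each sharded
    memory_cost_backward_activation = numel_activation / sharding_size_backward_activation * bytes_per_elem
    memory_cost_backward_weight = numel_weight / sharding_size_weight * bytes_per_elem
    memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
    return (memory_cost_forward, memory_cost_backward)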

View File

@ -1,14 +1,16 @@
from abc import ABC, abstractmethod
from typing import Dict, List
from webbrowser import Opera
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from torch.fx.node import Node
from typing import Dict, List
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from .._utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from .._utils import generate_resharding_costs, generate_sharding_spec
from ..sharding_strategy import StrategiesVector
__all__ = ['OperatorHandler']
@ -60,7 +62,7 @@ class OperatorHandler(ABC):
@abstractmethod
def register_strategy(self) -> StrategiesVector:
"""
Register all possible sharding strategies for this operator into the strategies vector.
"""
pass

View File

@ -4,9 +4,9 @@ import warnings
from copy import deepcopy
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

View File

@ -1,18 +1,21 @@
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ._utils import generate_resharding_costs, generate_sharding_spec
from .constants import *
from .op_handler import *
from .options import SolverOptions
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .op_handler import *
from .constants import *
from copy import deepcopy
import math
import torch
import operator
from typing import Dict, List
from ._utils import generate_sharding_spec, generate_resharding_costs
import builtins
__all__ = ['StrategiesConstructor']

View File

@ -0,0 +1,275 @@
from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Graph
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
class ModuleWrapper(nn.Module):
'''
This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict
into the forward function.
'''
def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
'''
Args:
module: the original module
sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node.
origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
'''
super(ModuleWrapper, self).__init__()
self.module = module
self.sharding_spec_dict = sharding_spec_dict
self.origin_spec_dict = origin_spec_dict
self.comm_actions_dict = comm_actions_dict
def forward(self, *args, **kwargs):
return self.module(*args,
sharding_spec_convert_dict=self.sharding_spec_dict,
origin_node_sharding_spec_dict=self.origin_spec_dict,
comm_actions_dict=self.comm_actions_dict,
**kwargs)
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
'''
This method is used to extract the meta_args from the dataloader with the help of the data_process_func.
'''
# TODO: implement this function
pass
def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
'''
This method is used to search the best logical mesh shape for the given world size
based on the alpha_beta_dict.
For example:
if the world_size is 8, the possible logical mesh shapes are (1, 8), (2, 4), (4, 2) and (8, 1).
'''
# TODO: implement this function
return (world_size, 1)
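# A minimal sketch of the enumeration hinted at above: every
# (p, world_size // p) factor pair is a candidate 2D logical mesh; the full
# implementation would rank these candidates with the alpha_beta_dict.
# (The helper name below is illustrative, not part of this module.)
def enumerate_logical_mesh_shapes(world_size: int):
    return [(p, world_size // p) for p in range(1, world_size + 1) if world_size % p == 0]
assert enumerate_logical_mesh_shapes(8) == [(1, 8), (2, 4), (4, 2), (8, 1)]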
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
'''
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
from the alpha_beta_dict. These two values will be used to estimate the communication cost.
'''
# TODO: implement this function
pass
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
'''
This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler.
'''
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
return strategies_constructor
def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
'''
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
'''
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
cost_graph = CostGraph(strategy_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
return solution
def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor):
'''
This method is used to transform the original graph to the sharded graph.
The model parameters will be sharded according to the solution and the grad hooks
will be added to the sharded graph using the runtime_preparation_pass.
The communication node will be added into the graph using the runtime_apply_pass.
'''
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
gm, solution, device_mesh, strategies_constructor)
gm = runtime_apply_pass(gm)
gm.recompile()
sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
return gm, sharding_spec_dicts
def initialize_device_mesh(world_size: int = -1,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None):
'''
This method is used to initialize the device mesh.
Args:
world_size(optional): the size of the device mesh. If the world_size is -1,
the world size will be taken from torch.distributed.
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
for each device pair. If the alpha_beta_dict is None, the alpha_beta_dict will be
generated by the profile_alpha_beta function.
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
'''
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
device1d = [i for i in range(world_size)]
if alpha_beta_dict is None:
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
alpha_beta_dict = profile_alpha_beta(device1d)
if logical_mesh_shape is None:
# search for the best logical mesh shape
logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
physical_mesh = torch.tensor(device1d)
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
mesh_shape=logical_mesh_shape,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=True)
return device_mesh
def initialize_model(model: nn.Module,
meta_args: Dict[str, torch.Tensor],
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
return_solution: bool = False):
'''
This method is used to initialize the sharded model, which can then be used as a normal PyTorch model.
Args:
model: the model to be sharded.
meta_args: the meta_args is used to specify the input shapes of the model.
device_mesh: the device mesh to execute the model.
memory_budget(optional): the maximum CUDA memory that can be used. If the memory budget is -1.0,
the memory budget is treated as infinite.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
from the solution_path.
solution_path(optional): the path to save or load the solution.
return_solution(optional): if the return_solution is True, the solution will be returned. The returned
solution is useful for debugging or analyzing the sharding result. Therefore, we do not just
return a series of integers, but the name of the chosen strategy for each node.
'''
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
strategies_constructor = build_strategy_constructor(graph, device_mesh)
if load_solver_solution:
solution = torch.load(solution_path)
else:
solution = solve_solution(gm, strategies_constructor, memory_budget)
if save_solver_solution:
torch.save(solution, solution_path)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
if return_solution:
solution_to_return = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}')
return model_to_return, solution_to_return
else:
return model_to_return
def autoparallelize(model: nn.Module,
meta_args: Dict[str, torch.Tensor] = None,
data_loader: torch.utils.data.DataLoader = None,
data_process_func: callable = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
return_solution: bool = False,
memory_budget: float = -1.0):
'''
This method is used to initialize the device mesh, extract the meta_args, and
use them to create a sharded model.
Args:
model: the model to be sharded.
meta_args(optional): the meta_args is used to specify the input shapes of the model.
If the meta_args is None, the meta_args will be extracted from the data_loader.
data_loader(optional): the data_loader to be used in normal training loop.
data_process_func(optional): the data_process_func is used to process the data from the data_loader.
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
for each device pair. If the alpha_beta_dict is None, the alpha_beta_dict will be
generated by the profile_alpha_beta function.
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
from the solution_path.
solver_solution_path(optional): the path to save or load the solution.
return_solution(optional): if the return_solution is True, the solution will be returned.
memory_budget(optional): the maximum CUDA memory that can be used. If the memory budget is -1.0,
the memory budget is treated as infinite.
'''
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
rst_to_unpack = initialize_model(model,
meta_args,
device_mesh,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solver_solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget)
if return_solution:
model, solution = rst_to_unpack
return model, solution
else:
model = rst_to_unpack
return model
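A hypothetical end-to-end usage sketch of the entry point above; the model, shapes, and meta_args key are illustrative, and it assumes torch.distributed has already been initialized so that initialize_device_mesh can query the world size:
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))
# meta_args specifies the input shapes without allocating real memory
meta_args = {'input': torch.rand(16, 512, device='meta')}
model, solution = autoparallelize(model, meta_args=meta_args, return_solution=True)
# each entry pairs a node name with its chosen strategy name
for entry in solution:
    print(entry)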

View File

@ -1,19 +1,31 @@
from .addmm_handler import ADDMMFunctionHandler
from .batch_norm_handler import BatchNormModuleHandler
from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
from .experimental import PermuteHandler, ViewHandler
from .getattr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OuputHandler
from .placeholder_handler import PlacehodlerHandler
from .output_handler import OutputHandler
from .placeholder_handler import PlaceholderHandler
from .registry import operator_registry
from .reshape_handler import ReshapeHandler
from .softmax_handler import SoftmaxHandler
from .sum_handler import SumHandler
from .tensor_constructor_handler import TensorConstructorHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler
__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
'NormPoolingHandler', 'operator_registry'
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
]

View File

@ -0,0 +1,91 @@
from typing import Dict, List, Union
import torch
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
__all__ = ['ADDMMFunctionHandler']
@operator_registry.register(torch.addmm)
@operator_registry.register(torch.Tensor.addmm)
class ADDMMFunctionHandler(NodeHandler):
"""
This is a NodeHandler class which deals with the addmm operation in PyTorch.
Such operations include `torch.addmm` and `torch.Tensor.addmm`, which compute
`beta * input + alpha * (mat1 @ mat2)` with the bias input broadcast to the matmul result.
"""
def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType:
if isinstance(tensor, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
return data_type
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# input operand
input_data = self.node.args[1]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[1]),
type=self._infer_op_data_type(input_data),
data=input_data)
# other operand
other_data = self.node.args[2]._meta_data
physical_other_operand = OperationData(name=str(self.node.args[2]),
type=self._infer_op_data_type(other_data),
data=other_data)
# bias operand, whose logical shape is the output shape since the bias is broadcast to the matmul result
bias_logical_shape = self.node._meta_data.shape
bias_data = self.node.args[0]._meta_data
physical_bias_operand = OperationData(name=str(self.node.args[0]),
type=self._infer_op_data_type(bias_data),
data=bias_data,
logical_shape=bias_logical_shape)
# output
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {
"input": physical_input_operand,
"other": physical_other_operand,
"output": physical_output,
'bias': physical_bias_operand
}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
bias_op_data = op_data_mapping['bias']
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
bias_sharding_spec, bias_logical_shape, bias_physical_shape)
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node,
removed_dims=removed_dims,
op_data=bias_op_data,
sharding_spec=bias_sharding_spec)
strategy.communication_actions[bias_op_data] = comm_action
return strategy
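The logical-versus-physical bias shape handled above, illustrated with plain torch: addmm broadcasts a 1D bias [q] against the [n, q] matmul result, which is why the output shape is recorded as the bias's logical shape (shapes below are illustrative):
import torch

bias = torch.randn(5)        # physical shape [q]
mat1 = torch.randn(3, 4)     # [n, p]
mat2 = torch.randn(4, 5)     # [p, q]
out = torch.addmm(bias, mat1, mat2)
assert out.shape == (3, 5)   # bias logically broadcast to [n, q]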

View File

@ -2,8 +2,10 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
@ -13,7 +15,7 @@ __all__ = ['BatchNormModuleHandler']
@operator_registry.register(torch.nn.BatchNorm1d)
@operator_registry.register(torch.nn.BatchNorm2d)
@operator_registry.register(torch.nn.BatchNorm3d)
class BatchNormModuleHandler(ModuleHandler):
class BatchNormModuleHandler(MetaInfoModuleHandler):
"""
A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
"""

View File

@ -0,0 +1,96 @@
from typing import Dict, List, Union
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..constants import BCAST_FUNC_OP
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
__all__ = ['BinaryElementwiseHandler']
@operator_registry.register(BCAST_FUNC_OP)
class BinaryElementwiseHandler(MetaInfoNodeHandler):
"""
A BinaryElementwiseHandler is a node handler which deals with operations that have two
operands and involve broadcasting, such as torch.add.
"""
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
bcast_shape = self.node._meta_data.shape
def _get_op_data_type(tensor):
if isinstance(tensor, torch.nn.parameter.Parameter):
return OperationDataType.PARAM
else:
return OperationDataType.ARG
def _get_arg_value(idx):
if isinstance(self.node.args[idx], Node):
meta_data = self.node.args[idx]._meta_data
else:
# this is in fact real data, e.g. the int 1,
# but we can treat it as meta data
# as it won't affect the strategy generation
assert isinstance(self.node.args[idx], (int, float))
meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
return meta_data
input_meta_data = _get_arg_value(0)
other_meta_data = _get_arg_value(1)
output_meta_data = self.node._meta_data
input_op_data = OperationData(name=str(self.node.args[0]),
type=_get_op_data_type(input_meta_data),
data=input_meta_data,
logical_shape=bcast_shape)
other_op_data = OperationData(name=str(self.node.args[1]),
type=_get_op_data_type(other_meta_data),
data=other_meta_data,
logical_shape=bcast_shape)
output_op_data = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_meta_data,
logical_shape=bcast_shape)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(BinaryElementwiseStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
for op_name, op_data in op_data_mapping.items():
if not isinstance(op_data.data, torch.Tensor):
# remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
strategy.sharding_specs.pop(op_data)
else:
# convert the logical sharding spec to physical sharding spec if broadcast
# e.g. torch.rand(4, 4) + torch.rand(4)
physical_shape = op_data.data.shape
logical_shape = op_data.logical_shape
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
sharding_spec, logical_shape, physical_shape)
strategy.sharding_specs[op_data] = sharding_spec
if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node,
removed_dims=removed_dims,
op_data=op_data,
sharding_spec=sharding_spec)
strategy.communication_actions[op_data] = comm_action
return strategy
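Plain-torch illustrations of the two post_process branches above (shapes are illustrative):
import torch

# scalar operand: the exponent in torch.pow(tensor, 2) is real data rather
# than a tensor, so no sharding spec is kept for it
y = torch.pow(torch.rand(4, 4), 2)
assert y.shape == (4, 4)

# broadcast operand: torch.rand(4) is logically expanded to (4, 4), so a
# spec on the broadcast dim must be recovered to the physical shape
z = torch.rand(4, 4) + torch.rand(4)
assert z.shape == (4, 4)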

View File

@ -2,8 +2,10 @@ from typing import Dict, List, Union
import torch
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler):
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
bias_physical_shape)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
bias_sharding_spec, bias_logical_shape, bias_physical_shape)
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node,
removed_dims=removed_dims,
op_data=bias_op_data,
sharding_spec=bias_sharding_spec)
strategy.communication_actions[bias_op_data] = comm_action
return strategy

View File

@ -3,9 +3,9 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from ..utils import transpose_partition_dim
from .node_handler import ModuleHandler, NodeHandler
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
@operator_registry.register(torch.nn.Conv1d)
@operator_registry.register(torch.nn.Conv2d)
@operator_registry.register(torch.nn.Conv3d)
class ConvModuleHandler(ModuleHandler):
class ConvModuleHandler(MetaInfoModuleHandler):
"""
A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module.
"""
@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler):
@operator_registry.register(F.conv1d)
@operator_registry.register(F.conv2d)
@operator_registry.register(F.conv3d)
class ConvFunctionHandler(NodeHandler):
class ConvFunctionHandler(MetaInfoNodeHandler):
"""
A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions.
"""

View File

@ -0,0 +1,230 @@
from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import update_partition_dim
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator
__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler']
def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str,
output_name: str) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output
of the embedding operation.
Args:
strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
input_name (str): the name of the OperationData object for the input.
output_name (str): the name of the OperationData object for the output.
"""
# the result will be a list of strategies
sharding_strategies = []
# get operation data
input_op_data = strategy.get_op_data_by_name(input_name)
output_op_data = strategy.get_op_data_by_name(output_name)
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
# recover the last logical dimension to physical dimension
last_logical_output_dims = len(output_op_data.logical_shape) - 1
last_physical_output_dims = output_op_data.data.dim() - 1
# get logger for debug message
logger = get_dist_logger()
# For the input of the embedding operation, it can be multi-dimensional. The sharding spec is only generated for
# the logical 1D non-matrix dimension, which can map to any of the 0th to Nth dimensions of the
# physical input shape. Thus, we enumerate to get all possible cases.
if input_sharding_spec.dim_partition_dict:
# if bool(input_sharding_spec.dim_partition_dict) is True, it means that
# the generated sharding strategy does shard the non-matrix dimension,
# in this case, we need to do enumeration
num_input_dims = input_op_data.data.dim()
for i in range(num_input_dims):
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {0: i}
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
strategy_copy.name = f'{strategy.name}_{i}'
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
f'Error occurred when converting the logical sharding spec to the physical one. Error details: {e}'
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
# in this case, we don't need to do enumeration
# but instead, we still need to convert the logical shape to physical shape
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
physical_shape=input_op_data.data.shape,
inplace=True)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {}
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(strategy_copy)
return sharding_strategies
@operator_registry.register(torch.nn.Embedding)
class EmbeddingModuleHandler(ModuleHandler):
"""
An EmbeddingModuleHandler which deals with the sharding strategies for the nn.Embedding module.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# In nn.Embedding operation, all the dimensions of input will be treated as the batch dimension,
# and then the sharding spec will be generated based on the logical 1D tensor.
# After that, the logical sharding info will be enumerated among all the physical dimensions.
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=input_meta_data,
logical_shape=input_logical_shape)
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'])
# Same as input, in nn.Embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
# on the logical 2D tensor.
# After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions.
# Finally, the output will be transformed back to its original shape in self.post_process
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_meta_data,
logical_shape=output_logical_shape)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
"""
Convert the sharding spec from the logical shape to the physical shape.
"""
# create multiple sharding strategies for the inputs
# as the input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
input_name=str(
self.node.args[0]),
output_name=str(self.node))
return strategies
@operator_registry.register(F.embedding)
class EmbeddingFunctionHandler(NodeHandler):
"""
An EmbeddingFunctionHandler which deals with the sharding strategies for F.embedding.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# In F.embedding operation, all the dimensions of input will be treated as the batch dimension,
# and then the sharding spec will be generated based on the logical 1D tensor.
# After that, the logical sharding info will be enumerated among all the physical dimensions.
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data)
# Same as input, in F.embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
# on the logical 2D tensor.
# After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions.
# Finally, the output will be transformed back to its original shape in self.post_process
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(
name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.node._meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping
def post_process(self, strategy: ShardingStrategy):
"""
Convert the sharding spec from the logical shape to the physical shape.
"""
# create multiple sharding strategies for the inputs
# as the input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
input_name=str(
self.node.args[0]),
output_name=str(self.node))
return strategies
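A plain-torch illustration of the logical flattening described in the comments above: the embedding input collapses to a logical 1D tensor and the output to a logical 2D (batch, embedding) tensor before post_process maps the sharding back to the physical dims (shapes are illustrative):
import torch

input_meta = torch.zeros(2, 3, dtype=torch.long, device='meta')
output_meta = torch.zeros(2, 3, 8, device='meta')    # embedding dim 8

assert input_meta.view(-1).shape == (6,)
assert output_meta.view(-1, output_meta.shape[-1]).shape == (6, 8)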

View File

@ -0,0 +1,10 @@
from .permute_handler import PermuteHandler
from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator
from .split_handler import SplitHandler
from .transpose_handler import TransposeHandler
from .view_handler import ViewHandler
__all__ = [
'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeHandler',
'SplitHandler', 'SplitGenerator'
]

View File

@ -0,0 +1,76 @@
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import PermuteGenerator
__all__ = ['PermuteHandler']
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.permute)
class PermuteHandler(NodeHandler):
"""
A PermuteHandler which deals with the sharding strategies for torch.permute and torch.Tensor.permute.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
permute_dims = []
if self.node.op == 'call_method':
# torch.Tensor.permute (input, *dims)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int):
permute_dims.append(arg._meta_data)
else:
assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
permute_dims.append(arg)
else:
# torch.permute (input, dims)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, (tuple, list)):
permute_dims.extend(arg._meta_data)
else:
assert isinstance(
arg,
(tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
permute_dims.extend(arg)
num_dims = self.node._meta_data.dim()
for i in range(num_dims):
# recover negative value to positive
if permute_dims[i] < 0:
permute_dims[i] += num_dims
physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"permute_dims": physical_shape_operand,
"output": physical_output_operand
}
return mapping
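# The negative-dim normalization above in isolation, as a minimal sketch
# (the helper name is illustrative, not part of this module):
def normalize_permute_dims(permute_dims, num_dims):
    # adding num_dims recovers a negative index to its positive form
    return [d + num_dims if d < 0 else d for d in permute_dims]
assert normalize_permute_dims([0, -1, -2], 4) == [0, 3, 2]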

View File

@ -0,0 +1,299 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
"""
ReshapeGenerator is the base class for all the reshape operation.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
return super().collate_strategies()
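# A numeric reading of update_memory_cost above, for a float32 view of a
# [1024, 1024] tensor with no sharding, assuming (hypothetically) that
# _compute_size_in_bytes returns numel * 4 in that case:
numel = 1024 * 1024
input_bytes = output_bytes = numel * 4            # 4 MiB each
fwd_activation = input_bytes + output_bytes       # fwd_cost = input + output
bwd_activation = input_bytes                      # bwd_cost = input_grad
assert fwd_activation + bwd_activation == 3 * numel * 4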
class ViewGenerator(ReshapeGenerator):
"""
ViewGenerator deals with the sharding strategies of view op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
origin_shape = self.op_data['input'].data.shape
tgt_shape = self.op_data['tgt_shape'].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
reshape_mapping_dict)
else:
dim_partition_dict_for_output = {}
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into the name to pass the duplicate check
# we keep the same strategies under different names for node merging, and it will not increase the search space,
# because in the solver, this node will be merged into other nodes, and the solver will not create a new variable for this node.
if keep_sharding_status:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
else:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
# add comm action for converting input to fully replicated
total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
# the total mesh dim list only has one element, so the shard dim has only one element as well.
shard_dim = list(dim_partition_dict_for_input.keys())[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
input_comm_action.comm_spec.shard_dim = shard_dim
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class PermuteGenerator(ReshapeGenerator):
"""
PermuteGenerator deals with the sharding strategies of permute op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
permute_dims = self.op_data['permute_dims'].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
if permute_dim in dim_partition_dict_for_input:
dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into the name to pass the duplicate check
# we keep the same strategies under different names for node merging, and it will not increase the search space,
# because in the solver, this node will be merged into other nodes, and the solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class TransposeGenerator(ReshapeGenerator):
"""
TransposeGenerator deals with the sharding strategies of transpose op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
transpose_dims = self.op_data['transpose_dims'].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
if dim == dim_0:
dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
elif dim == dim_1:
dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
else:
dim_partition_dict_for_output[dim] = sharded_dims
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into the name to pass the duplicate check
# we keep the same strategies under different names for node merging, and it will not increase the search space,
# because in the solver, this node will be merged into other nodes, and the solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
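For transpose, the same remapping reduces to swapping the sharding of the two transposed dims; a small hypothetical sketch using plain dicts:

# transpose(input, 0, 2) with tensor dims 0 and 1 sharded
dim_0, dim_1 = 0, 2
dim_partition_for_input = {0: [0], 1: [1]}

dim_partition_for_output = {}
for dim, sharded_dims in dim_partition_for_input.items():
    if dim == dim_0:
        dim_partition_for_output[dim_1] = sharded_dims
    elif dim == dim_1:
        dim_partition_for_output[dim_0] = sharded_dims
    else:
        dim_partition_for_output[dim] = sharded_dims

assert dim_partition_for_output == {2: [0], 1: [1]}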
class SplitGenerator(ReshapeGenerator):
"""
SplitGenerator deals with the sharding strategies of split op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
recover_dims = None
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
split_size, split_dim = self.op_data['split_info'].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
dim_partition_dict_for_output = [
copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
]
assert len(dim_partition_dict_for_output) >= 2
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep the same strategies under different names for node merging; this does not increase the search space,
# because the solver merges this node into other nodes and creates no new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
# add comm action if the input need to be recovered to replica in the split dimension.
if recover_dims:
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(recover_dims) == 1:
recover_dims = recover_dims[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during backward phase.
input_comm_action.comm_spec.shard_dim = split_dim
elif len(recover_dims) >= 2:
# original sharding spec
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
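A hypothetical sketch of the recover_dims bookkeeping above: splitting along a sharded dim forces that dim back to replica, and the popped mesh axes decide between an all-gather (one axis) and a full shape-consistency pass (two or more).

import copy

dim_partition_for_input = {0: [0], 1: [1]}
split_dim = 1
recover_dims = None

work = copy.deepcopy(dim_partition_for_input)
if split_dim in work:
    recover_dims = work.pop(split_dim)

# one mesh axis popped -> a gather along mesh axis 1 is inserted before the split
assert work == {0: [0]} and recover_dims == [1]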

View File

@ -0,0 +1,63 @@
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import SplitGenerator
__all__ = ['SplitHandler']
@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
class SplitHandler(NodeHandler):
"""
A SplitHandler which deals with the sharding strategies for torch.Tensor.split or torch.split.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
split_size = self.node.args[1]
if len(self.node.args) == 3:
# (input, split_size, split_dim)
split_dim = self.node.args[2]
else:
if self.node.kwargs:
split_dim = self.node.kwargs['dim']
else:
split_dim = 0
num_dims = self.node.args[0]._meta_data.dim()
# recover negative value to positive
if split_dim < 0:
split_dim += num_dims
split_info = (split_size, split_dim)
physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"split_info": physical_shape_operand,
"output": physical_output_operand
}
return mapping
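A standalone sketch of the argument parsing above, outside torch.fx; the helper name and calls are hypothetical.

def parse_split_args(num_dims, args, kwargs):
    split_size = args[1]
    if len(args) == 3:                 # torch.split(input, split_size, dim)
        split_dim = args[2]
    else:
        split_dim = kwargs.get('dim', 0)
    if split_dim < 0:                  # recover negative value to positive
        split_dim += num_dims
    return split_size, split_dim

assert parse_split_args(3, ('x', 2, -1), {}) == (2, 2)
assert parse_split_args(3, ('x', 2), {'dim': 1}) == (2, 1)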

View File

@ -0,0 +1,65 @@
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import TransposeGenerator
__all__ = ['TransposeHandler']
@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.transpose)
class TransposeHandler(NodeHandler):
"""
A TransposeHandler which deals with the sharding strategies for torch.Tensor.transpose or torch.transpose.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
transpose_dims = []
# torch.transpose (input, dim0, dim1)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int):
transpose_dims.append(arg._meta_data)
else:
transpose_dims.append(arg)
num_dims = self.node._meta_data.dim()
for i in range(2):
# recover negative value to positive
if transpose_dims[i] < 0:
transpose_dims[i] += num_dims
physical_shape_operand = OperationData(name='transpose_dims',
type=OperationDataType.ARG,
data=list(transpose_dims))
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"transpose_dims": physical_shape_operand,
"output": physical_output_operand
}
return mapping

View File

@ -0,0 +1,53 @@
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import ViewGenerator
__all__ = ['ViewHandler']
@operator_registry.register(torch.Tensor.reshape)
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.view)
class ViewHandler(NodeHandler):
"""
A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
target_shape = self.node._meta_data.shape
physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"tgt_shape": physical_shape_operand,
"output": physical_output_operand
}
return mapping

View File

@ -0,0 +1,34 @@
from typing import Dict, List
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .strategy import GetattrGenerator, StrategyGenerator
__all__ = ['GetattrHandler']
class GetattrHandler(NodeHandler):
"""
A GetattrHandler which deals with the sharding strategies for Getattr Node.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(GetattrGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
# There are only two possible types for get_attr node:
# 1. torch.Tensor(torch.nn.Parameters or torch.nn.Buffers)
# 2. torch.nn.Module
# temporarily, we only support the first case in the Tracer, so we don't have to worry about
# issues related to the node._meta_data type.
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"output": physical_output}
return mapping

View File

@ -6,7 +6,7 @@ import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
__all__ = ['GetItemHandler']

View File

@ -3,12 +3,16 @@ from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
from colossalai.auto_parallel.tensor_shard.utils import (
check_sharding_spec_validity,
transpose_partition_dim,
update_partition_dim,
)
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
@ -28,9 +32,11 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
# switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name)
assert op_data.logical_shape != op_data.data.shape, \
"Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
transpose_partition_dim(sharding_spec, 0, -1)
assert op_data.logical_shape[0] == op_data.data.shape[1] and \
op_data.logical_shape[1] == op_data.data.shape[0], \
"Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
dim_size = len(op_data.logical_shape)
transpose_partition_dim(sharding_spec, 0, dim_size - 1)
return strategy
@ -54,6 +60,23 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
input_op_data = strategy.get_op_data_by_name(input_name)
output_op_data = strategy.get_op_data_by_name(output_name)
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
# recover the last logical dimension to physical dimension
last_logical_input_dims = len(input_op_data.logical_shape) - 1
last_logical_output_dims = len(output_op_data.logical_shape) - 1
last_physical_input_dims = input_op_data.data.dim() - 1
last_physical_output_dims = output_op_data.data.dim() - 1
if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
else:
input_last_dim_mapping = {}
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
output_last_dim_mapping = {}
# get logger for debug message
logger = get_dist_logger()
@ -73,14 +96,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
input_dim_mapping = {0: i}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
output_dim_mapping = {0: i}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
strategy_copy.name = f'{strategy.name}_{i}'
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
@ -95,12 +125,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
input_dim_mapping = {}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
output_dim_mapping = {}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={},
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(strategy_copy)
@ -108,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
class LinearModuleHandler(MetaInfoModuleHandler):
"""
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
"""
@ -116,7 +151,8 @@ class LinearModuleHandler(ModuleHandler):
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@ -167,15 +203,16 @@ class LinearModuleHandler(ModuleHandler):
@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
class LinearFunctionHandler(MetaInfoNodeHandler):
"""
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@ -198,27 +235,34 @@ class LinearFunctionHandler(NodeHandler):
type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=self.node.args[1]._meta_data.shape[::-1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(
name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.node._meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.node.args[2] is not None:
if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
# check if the other operand is a parameter
if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.args[2]),
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
type=data_type,
data=self.node.args[2]._meta_data)
data=self.node.kwargs["bias"]._meta_data)
mapping['bias'] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
# switch the dimensions of the transposed weight
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
weight_name=str(self.node.args[1]))
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input

View File

@ -0,0 +1,486 @@
import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import Dict, List, Union
import torch
from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
BroadcastType,
get_broadcast_dim_info,
get_broadcast_shape,
)
from colossalai.tensor.sharding_spec import ShardingSpecException
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
DotProductStrategyGenerator,
LinearProjectionStrategyGenerator,
MatVecStrategyGenerator,
StrategyGenerator,
)
class MatMulType(Enum):
"""
The MatMulType is categorized into 4 types based on the reference of torch.matmul
in https://pytorch.org/docs/stable/generated/torch.matmul.html.
DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements
MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D
MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
"""
DOT = 0
MM = 1
MV = 2
BMM = 3
def get_matmul_type(input_dim: int, other_dim: int):
"""
Determine which type of matmul operation should be executed for the given tensor dimensions.
Args:
input_dim (int): the number of dimensions for the input tensor
other_dim (int): the number of dimensions for the other tensor
"""
if input_dim == 1 and other_dim == 1:
matmul_type = MatMulType.DOT
elif input_dim in [1, 2] and other_dim == 2:
matmul_type = MatMulType.MM
elif input_dim == 2 and other_dim == 1:
matmul_type = MatMulType.MV
elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
matmul_type = MatMulType.BMM
else:
raise ValueError(
f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation"
)
return matmul_type
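A few illustrative classifications, assuming only the definitions above (the shapes in the comments are hypothetical):

assert get_matmul_type(1, 1) == MatMulType.DOT    # [4] @ [4]
assert get_matmul_type(1, 2) == MatMulType.MM     # [4] @ [4, 8]
assert get_matmul_type(2, 2) == MatMulType.MM     # [2, 4] @ [4, 8]
assert get_matmul_type(2, 1) == MatMulType.MV     # [2, 4] @ [4]
assert get_matmul_type(3, 2) == MatMulType.BMM    # [b, 2, 4] @ [4, 8]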
class BmmTransform(ABC):
"""
BmmTransform is an abstraction of the shape conversion between logical and physical operation data
during the strategy generation.
"""
@abstractmethod
def apply(self, shape_mapping: Dict[str, List[int]]):
pass
@abstractmethod
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
pass
class Padder(BmmTransform):
"""
Add padding to the matrix dimensions for batched matrix multiplication.
"""
def __init__(self) -> None:
# keep the padding dim, op_name -> padded_dim
self.padded_dim_mapping = {}
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = deepcopy(shape_mapping)
input_shape = mapping_copy['input']
other_shape = mapping_copy['other']
if len(input_shape) == 1:
# if the input is a 1D tensor, 1 is prepended to its shape
# and it will be removed afterwards
input_shape.insert(0, 1)
self.padded_dim_mapping['input'] = -2
self.padded_dim_mapping['output'] = -2
elif len(other_shape) == 1:
# if the other is a 1D tensor, 1 is appended to its shape
# and it will be removed afterwards
other_shape.append(1)
self.padded_dim_mapping['other'] = -1
self.padded_dim_mapping['output'] = -1
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
input_op_data = op_data_mapping['input']
other_op_data = op_data_mapping['other']
def _remove_padded_dim(key, strategy):
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
tensor_shape = list(sharding_spec.entire_shape)
dim_partition_list = [None] * len(tensor_shape)
# padded dim is a negative number as the padded dim must be a matrix dim
padded_dim = self.padded_dim_mapping[key]
# compute the new dim partition
for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
dim_partition_list[tensor_dim] = mesh_dims
dim_partition_list.pop(padded_dim)
unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}
# compute unpadded tensor shape
tensor_shape.pop(padded_dim)
assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
# update sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
# only one of input and other will be padded
if 'input' in self.padded_dim_mapping:
_remove_padded_dim('input', strategy_copy)
_remove_padded_dim('output', strategy_copy)
elif 'other' in self.padded_dim_mapping:
_remove_padded_dim('other', strategy_copy)
_remove_padded_dim('output', strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
pass
return strategies
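A usage sketch of the class above with hypothetical shapes, showing the padding bookkeeping it records:

padder = Padder()
padded = padder.apply({'input': [4], 'other': [7, 4, 8]})
assert padded == {'input': [1, 4], 'other': [7, 4, 8]}
assert padder.padded_dim_mapping == {'input': -2, 'output': -2}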
class Broadcaster(BmmTransform):
"""
Broadcast the non-matrix dimensions for batched matrix multiplication.
"""
def __init__(self) -> None:
self.broadcast_dim_info = {}
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
# get shapes
input_shape = mapping_copy['input']
other_shape = mapping_copy['other']
# sanity check
assert len(input_shape) > 1 and len(other_shape) > 1
# broadcast the batch dim and record
bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2])
# store the broadcast dim info
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
self.broadcast_dim_info['input'] = input_broadcast_dim_info
self.broadcast_dim_info['other'] = other_broadcast_dim_info
# create the full logical shape
input_shape = bcast_non_matrix_dims + input_shape[-2:]
other_shape = bcast_non_matrix_dims + other_shape[-2:]
assert len(input_shape) == len(other_shape)
mapping_copy['input'] = input_shape
mapping_copy['other'] = other_shape
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
# remove sharding on the broadcast dim
def _remove_sharding_on_broadcast_dim(key, strategy):
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
tensor_shape = list(sharding_spec.entire_shape)
for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
if broadcast_type == BroadcastType.MULTIPLE:
# if the dim is originally 1 and multiplied during broadcast
# we set its sharding to R
# e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
# the dim 0 of [1, 2, 4] is multiplied to 4
tensor_shape[dim_idx] = 1
elif broadcast_type == BroadcastType.PADDDING:
# if the dim is padded
# we remove its sharding
tensor_shape[dim_idx] = None
tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape,
physical_shape=tensor_shape_before_broadcast)
strategy.sharding_specs[op_data] = physical_sharding_spec
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
_remove_sharding_on_broadcast_dim('input', strategy_copy)
_remove_sharding_on_broadcast_dim('other', strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
pass
return strategies
class Viewer(BmmTransform):
"""
Change the shape of the tensor from N-D to 3D
"""
def __init__(self) -> None:
self.batch_dims_before_view = None
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
self.batch_dims_before_view = list(mapping_copy['input'][:-2])
# get shapes
input_shape = shape_mapping['input']
other_shape = shape_mapping['other']
# view to 3d tensor
assert len(input_shape) >= 3 and len(other_shape) >= 3
input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
output_shape = input_shape[:2] + other_shape[2:]
mapping_copy['input'] = input_shape
mapping_copy['other'] = other_shape
mapping_copy['output'] = output_shape
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
# get operation data
def _update_sharding_spec(key, strategy, physical_batch_dim):
"""
Map the logical batch dim to the physical batch dim
"""
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
dim_partition_dict = sharding_spec.dim_partition_dict
entire_shape = sharding_spec.entire_shape
# update the dimension index for the matrix dimensions
if 2 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
if 1 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)
# map the logical batch dim to physical batch dim
if 0 in dim_partition_dict:
batch_dim_shard = dim_partition_dict.pop(0)
dim_partition_dict[physical_batch_dim] = batch_dim_shard
# the new shape will be the batch dims + the last 2 matrix dims
shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)
num_batch_dim_before_view = len(self.batch_dims_before_view)
# enumerate all sharding strategies
strategies = []
for i in range(num_batch_dim_before_view):
# create a new strategy
strategy_copy = strategy.clone()
try:
_update_sharding_spec('input', strategy_copy, i)
_update_sharding_spec('other', strategy_copy, i)
_update_sharding_spec('output', strategy_copy, i)
strategies.append(strategy_copy)
except ShardingSpecException as e:
continue
return strategies
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
"""
Compute the logical shapes for BMM operation. BMM has a general representation
[b, i, k] = [b, i, j] x [b, j, k]
The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions
The logical shape for the bmm operands will undergo three stages
1. prepend/append a 1 to the shape of a 1D operand if there is any
2. broadcast the non-matrix dimensions
3. reshape to 3 dimensions
"""
shape_mapping = {'input': input_shape, 'other': other_shape}
for transform in transforms:
shape_mapping = transform.apply(shape_mapping)
input_shape = shape_mapping.get('input', None)
other_shape = shape_mapping.get('other', None)
output_shape = shape_mapping.get('output', None)
return input_shape, other_shape, output_shape
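A worked trace of the three stages for a hypothetical [4] @ [7, 4, 8] matmul, using torch.broadcast_shapes in place of the library's broadcast helper:

import torch

input_shape, other_shape = [4], [7, 4, 8]
# stage 1: pad the 1D operand
input_shape = [1] + input_shape                              # [1, 4]
# stage 2: broadcast the non-matrix (batch) dims
batch = list(torch.broadcast_shapes(torch.Size(input_shape[:-2]),
                                    torch.Size(other_shape[:-2])))
input_shape = batch + input_shape[-2:]                       # [7, 1, 4]
other_shape = batch + other_shape[-2:]                       # [7, 4, 8]
# stage 3: collapse batch dims into one (already 3D here)
output_shape = input_shape[:2] + other_shape[2:]             # [7, 1, 8]
assert (input_shape, other_shape, output_shape) == ([7, 1, 4], [7, 4, 8], [7, 1, 8])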
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
"""
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
the operands.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# check which type of operation this matmul will call
self.input_meta_data = self.node.args[0]._meta_data
self.other_meta_data = self.node.args[1]._meta_data
self.output_meta_data = self.node._meta_data
input_dim = self.input_meta_data.dim()
other_dim = self.other_meta_data.dim()
self.matmul_type = get_matmul_type(input_dim, other_dim)
if self.matmul_type == MatMulType.BMM:
# bmm operation can possibly involve padding, broadcasting and view
# these transforms will be used to create logical shape and
# recover physical sharding spec
self.transforms = [Padder(), Broadcaster(), Viewer()]
else:
self.transforms = None
def get_strategy_generator(self) -> List[StrategyGenerator]:
generators = []
op_data_mapping = self.get_operation_data_mapping()
if self.matmul_type == MatMulType.BMM:
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.DOT:
generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MV:
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM:
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
logical_shape_func = {
MatMulType.DOT: self._get_logical_shape_for_dot,
MatMulType.MM: self._get_logical_shape_for_mm,
MatMulType.MV: self._get_logical_shape_for_mv,
MatMulType.BMM: self._get_logical_shape_for_bmm
}
logical_shapes = logical_shape_func[self.matmul_type]()
op_data_mapping = self._get_op_data_mapping(*logical_shapes)
return op_data_mapping
def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
# convert list to torch.Size
if input_logical_shape:
input_logical_shape = torch.Size(input_logical_shape)
if other_logical_shape:
other_logical_shape = torch.Size(other_logical_shape)
if output_logical_shape:
output_logical_shape = torch.Size(output_logical_shape)
# create op data
input_op_data = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.input_meta_data,
logical_shape=input_logical_shape)
other_op_data = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.other_meta_data,
logical_shape=other_logical_shape)
output_op_data = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.output_meta_data,
logical_shape=output_logical_shape)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
return mapping
def _get_logical_shape_for_dot(self):
"""
The operands for the dot operation have the same logical shape as the physical shape
"""
return None, None, None
def _get_logical_shape_for_mm(self):
"""
We need to handle the input tensor for a matrix-matrix multiplication as the input
tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
(e.g. [4] -> [1, 4]).
"""
if self.input_meta_data.dim() == 1:
input_logical_shape = [1] + list(self.input_meta_data.shape)
input_logical_shape = torch.Size(input_logical_shape)
else:
input_logical_shape = None
return input_logical_shape, None, None
def _get_logical_shape_for_mv(self):
"""
No broadcasting or dim insertion occurs for matrix-vector operation.
"""
return None, None, None
def _get_logical_shape_for_bmm(self):
input_physical_shape = list(self.input_meta_data.shape)
other_physical_shape = list(self.other_meta_data.shape)
return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
return strategy
elif self.matmul_type == MatMulType.MM:
if self.input_meta_data.dim() == 1:
# if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
# we need to remove that dim
input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
input_physical_shape = self.node.args[0]._meta_data.shape
dim_partition_dict = input_sharding_spec.dim_partition_dict
# remove the partitioning in the dim 0
if 0 in dim_partition_dict:
dim_partition_dict.pop(0, None)
# move the partitioning in dim 1 to dim 0
if -1 in dim_partition_dict:
shard = dim_partition_dict.pop(-1)
dim_partition_dict[0] = shard
if 1 in dim_partition_dict:
shard = dim_partition_dict.pop(1)
dim_partition_dict[0] = shard
# re-init the sharding spec
input_sharding_spec.__init__(input_sharding_spec.device_mesh,
entire_shape=input_physical_shape,
dim_partition_dict=dim_partition_dict)
return strategy
else:
return strategy
elif self.matmul_type == MatMulType.BMM:
op_data_mapping = self.get_operation_data_mapping()
strategies = [strategy]
# recover the physical sharding spec
for transform in self.transforms[::-1]:
recovered_strategies = []
for strategy_ in strategies:
output = transform.recover(op_data_mapping, strategy_)
if isinstance(output, ShardingStrategy):
recovered_strategies.append(output)
elif isinstance(output, (list, tuple)):
recovered_strategies.extend(output)
else:
raise TypeError(
f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
strategies = recovered_strategies
return strategies

View File

@ -1,11 +1,14 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from typing import Dict, List, Tuple, Union
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingSpec,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
@ -49,7 +52,16 @@ class NodeHandler(ABC):
for node in self.predecessor_node:
node_name = str(node)
# get the current sharding spec generated by this node handler
# we will not compute the resharding costs for nodes that are not counted in the strategy,
# and nodes with tuple or list outputs need to be handled below.
node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()]
if str(node) not in node_in_strategy:
continue
op_data = strategy.get_op_data_by_name(node_name)
current_sharding_spec = strategy.sharding_specs[op_data]
# get the sharding specs for this node generated
# in its own node handler
assert hasattr(node, 'strategies_vector'), \
@ -59,27 +71,83 @@ class NodeHandler(ABC):
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
]
# get the current sharding spec generated by this node handler
op_data = strategy.get_op_data_by_name(node_name)
current_sharding_spec = strategy.sharding_specs[op_data]
# create data structure to store costs
if op_data not in resharding_costs:
if node not in resharding_costs:
resharding_costs[node] = []
def _compute_resharding_cost(prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
"""
This is a helper function to compute the resharding cost for a specific strategy of a node.
"""
if prev_sharding_spec is None:
return TrainCycleItem(fwd=0, bwd=0, total=0)
elif isinstance(prev_sharding_spec, ShardingSpec):
if isinstance(data, torch.Tensor):
dtype = data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
_, _, consistency_cost = shape_consistency_manager.shape_consistency(
prev_sharding_spec, current_sharding_spec)
resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
bwd=consistency_cost["backward"] * size_per_elem_bytes,
total=consistency_cost["total"] * size_per_elem_bytes)
return resharding_cost
else:
# This raise is used to check if we have missed any type of data.
# It could be merged into Parameter branch, which means we won't handle
# non-tensor arguments.
raise ValueError(f'Unsupported data type {type(data)}')
else:
assert isinstance(prev_sharding_spec, (tuple, list)), \
f'prev_sharding_spec should be of type ShardingSpec, List[ShardingSpec], \
or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
fwd_cost = 0
bwd_cost = 0
total_cost = 0
for index, (prev_sharding_spec_item,
current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
current_sharding_spec)):
item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
data[index])
fwd_cost += item_cost.fwd
bwd_cost += item_cost.bwd
total_cost += item_cost.total
resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)
return resharding_cost
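The helper above converts the element counts returned by shape_consistency into bytes via the dtype's element size; a minimal sketch with hypothetical values:

import torch

size_per_elem_bytes = torch.tensor([], dtype=torch.float16).element_size()
assert size_per_elem_bytes == 2
fwd_cost_bytes = 1024 * size_per_elem_bytes    # 1024 fp16 elements -> 2048 bytes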
# for each sharding spec generated by the predecessor's node handler
# compute the resharding cost to switch to the sharding spec generated
# by the current node handler
for prev_sharding_spec in prev_sharding_specs:
_, _, resharding_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec,
current_sharding_spec)
resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
bwd=resharding_cost["backward"],
total=resharding_cost["total"])
resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
resharding_costs[node].append(resharding_cost)
strategy.resharding_costs = resharding_costs
return strategy
def get_target_function(self) -> callable:
"""
This function is used to get the target function for the node handler.
The target function is used to analyze the costs of strategies.
"""
if self.node.op in ('placeholder', 'get_attr', 'output'):
return None
if self.node.op == 'call_module':
target = self.node.graph.owning_module.get_submodule(self.node.target)
elif self.node.op == 'call_function':
target = self.node.target
elif self.node.op == 'call_method':
target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
else:
raise ValueError(f'Unsupported node type: {self.node.op}')
return target
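For the 'call_method' branch, torch.fx stores the method name as a string, so the callable is looked up on the class of the first argument's meta data; a hypothetical example:

import torch

# a traced x.transpose(0, 1) has node.target == 'transpose'
target = getattr(torch.Tensor, 'transpose')
assert target is torch.Tensor.transpose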
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
"""
Register different sharding strategies for the current node.
@ -151,6 +219,38 @@ class NodeHandler(ABC):
pass
class MetaInfoNodeHandler(NodeHandler):
"""
This is a base class to handle the nodes patched in the meta profiler.
Note: this class will be integrated into the NodeHandler class in the future, after
all the functions are patched.
"""
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
"""
This method is inherited from NodeHandler. It will register the strategies first,
and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()
# Currently we haven't patched all the torch functions and modules, so if the target
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
return self.strategies_vector
class ModuleHandler(NodeHandler):
def __init__(self, *args, **kwargs) -> None:
@ -168,3 +268,35 @@ class ModuleHandler(NodeHandler):
self.module = module
self.named_parameters = named_parameters
self.named_buffers = named_buffers
class MetaInfoModuleHandler(ModuleHandler):
"""
This is a base class to handle the module patched in the meta profiler.
Note: this class will be integrated into the ModuleHandler class in the future, after
all the modules are patched.
"""
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
"""
This method is inherited from NodeHandler. It will register the strategies first,
and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()
# Currently we haven't patched all the torch functions and modules, so if the target
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
return self.strategies_vector

View File

@ -3,7 +3,7 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler']
@operator_registry.register(torch.nn.AvgPool1d)
@operator_registry.register(torch.nn.AvgPool2d)
@operator_registry.register(torch.nn.AvgPool3d)
class NormPoolingHandler(ModuleHandler):
class NormPoolingHandler(MetaInfoModuleHandler):
"""
A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module.
"""

View File

@ -2,38 +2,51 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from colossalai.device.device_mesh import DeviceMesh
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator
__all__ = ['OuputHandler']
__all__ = ['OutputHandler']
class OuputHandler(NodeHandler):
class OutputHandler(NodeHandler):
"""
A OuputHandler which deals with the sharding strategies for Output Node.
An OutputHandler which deals with the sharding strategies for Output Node.
"""
def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
output_option: str) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.output_option = output_option
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node))
generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
dummy_output = torch.empty(1,).to("meta")
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=dummy_output)
mapping = {"output": physical_output}
mapping = {}
output_meta_data = []
for index, input_node in enumerate(self.predecessor_node):
assert hasattr(input_node, "_meta_data"), f'Input node {input_node.name} has no _meta_data attribute.'
physical_inputs = OperationData(name=str(input_node),
type=OperationDataType.ARG,
data=input_node._meta_data)
input_meta_data = input_node._meta_data
physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
name_key = f'input_{index}'
mapping[name_key] = physical_inputs
output_meta_data.append(input_meta_data)
assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
if len(output_meta_data) == 1:
output_meta_data = output_meta_data[0]
else:
output_meta_data = tuple(output_meta_data)
self.node._meta_data = output_meta_data
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping["output"] = physical_output
return mapping
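A small sketch of the output meta-data aggregation above, with hypothetical meta tensors: a single input is kept as-is, while multiple inputs are packed into a tuple.

import torch

metas = [torch.empty(2, 2, device='meta'), torch.empty(3, device='meta')]
output_meta = metas[0] if len(metas) == 1 else tuple(metas)
assert isinstance(output_meta, tuple) and len(output_meta) == 2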

View File

@ -1,21 +1,31 @@
from typing import Dict, List
from ..sharding_strategy import OperationData, OperationDataType
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator
__all__ = ['PlacehodlerHandler']
__all__ = ['PlaceholderHandler']
class PlacehodlerHandler(NodeHandler):
class PlaceholderHandler(NodeHandler):
"""
A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
"""
def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
placeholder_option: str) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.placeholder_option = placeholder_option
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh))
generators.append(
PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:

View File

@ -8,7 +8,12 @@ class Registry:
def register(self, source):
def wrapper(func):
self.store[source] = func
if isinstance(source, (list, tuple)):
# support register a list of items for this func
for element in source:
self.store[element] = func
else:
self.store[source] = func
return func
return wrapper
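A usage sketch of the new list form; the registry name and keys are hypothetical, assuming Registry is constructed with a name as elsewhere in this file:

demo_registry = Registry('demo')

@demo_registry.register(['relu', 'gelu'])
def activation_handler():
    ...

assert demo_registry.store['relu'] is demo_registry.store['gelu'] is activation_handler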

View File

@ -3,18 +3,17 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
class ReshapeHandler(MetaInfoNodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
@ -25,13 +24,47 @@ class ReshapeHandler(NodeHandler):
generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def infer_logical_shape(self, data):
"""
This function is used to infer the logical shape for operands.
Note: this function is only needed for operands whose data is not a single tensor,
such as a tuple of tensors.
"""
if isinstance(data, torch.Tensor):
return data.shape
else:
assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
logical_shape = []
for tensor in data:
assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor."
logical_shape.append(tensor.shape)
logical_shape = tuple(logical_shape)
return logical_shape
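A standalone sketch of the same logic with hypothetical meta tensors:

import torch

def infer_logical_shape(data):
    if isinstance(data, torch.Tensor):
        return data.shape
    return tuple(t.shape for t in data)

t = torch.empty(4, 8, device='meta')
assert infer_logical_shape(t) == torch.Size([4, 8])
assert infer_logical_shape(torch.split(t, 2, dim=0)) == (torch.Size([2, 8]), torch.Size([2, 8]))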
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
input_logical_shape = self.infer_logical_shape(input_data)
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
type=data_type,
data=input_data,
logical_shape=input_logical_shape)
output_data = self.node._meta_data
output_logical_shape = self.infer_logical_shape(output_data)
physical_output = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_data,
logical_shape=output_logical_shape)
mapping = {"input": physical_input_operand, "output": physical_output}

View File

@ -0,0 +1,55 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator
__all__ = ['SoftmaxHandler']
@operator_registry.register(torch.nn.Softmax)
@operator_registry.register(torch.nn.functional.softmax)
class SoftmaxHandler(NodeHandler):
"""
A SoftmaxHandler which deals with the sharding strategies for
torch.nn.Softmax or torch.nn.functional.softmax.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
softmax_dim = self.node.kwargs['dim']
num_dims = self.node.args[0]._meta_data.dim()
# recover negative value to positive
if softmax_dim < 0:
softmax_dim += num_dims
physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"softmax_dim": physical_dim_operand,
"output": physical_output_operand
}
return mapping

Some files were not shown because too many files have changed in this diff.