diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 845f7af06..c09c10308 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -20,6 +20,8 @@ body: A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. + **Optional: Affiliation** + Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. placeholder: | A clear and concise description of what the bug is. validations: diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml index e31193ba8..511997e2e 100644 --- a/.github/ISSUE_TEMPLATE/documentation.yml +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -17,6 +17,7 @@ body: **Expectation** What is your expected content about it? **Screenshots** If applicable, add screenshots to help explain your problem. **Suggestions** Tell us how we could improve the documentation. + **Optional: Affiliation** Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. placeholder: | A clear and concise description of the issue. validations: diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 8dcc51ea8..d05bc25f6 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -22,6 +22,8 @@ body: If applicable, add screenshots to help explain your problem. **Suggest a potential alternative/fix** Tell us how we could improve this project. + **Optional: Affiliation** + Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. placeholder: | A clear and concise description of your idea. validations: diff --git a/.github/ISSUE_TEMPLATE/proposal.yml b/.github/ISSUE_TEMPLATE/proposal.yml index 6ca7bd1a0..614ef7775 100644 --- a/.github/ISSUE_TEMPLATE/proposal.yml +++ b/.github/ISSUE_TEMPLATE/proposal.yml @@ -13,6 +13,7 @@ body: - Bumping a critical dependency's major version; - A significant improvement in user-friendliness; - Significant refactor; + - Optional: Affiliation/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. - ... Please note this is not for feature request or bug template; such action could make us identify the issue wrongly and close it without doing anything. @@ -43,4 +44,4 @@ body: - type: markdown attributes: value: > - Thanks for contributing 🎉! \ No newline at end of file + Thanks for contributing 🎉! diff --git a/.github/reviewer_list.yml b/.github/reviewer_list.yml deleted file mode 100644 index ce1d4849f..000000000 --- a/.github/reviewer_list.yml +++ /dev/null @@ -1,9 +0,0 @@ -addReviewers: true - -addAssignees: author - -numberOfReviewers: 1 - -reviewers: - - frankleeeee - - kurisusnowdeng diff --git a/.github/workflows/assign_reviewer.yml b/.github/workflows/assign_reviewer.yml deleted file mode 100644 index 6ebb33982..000000000 --- a/.github/workflows/assign_reviewer.yml +++ /dev/null @@ -1,18 +0,0 @@ -name: Assign Reviewers for Team - -on: - pull_request: - types: [opened] - -jobs: - assign_reviewer: - name: Assign Reviewer for PR - runs-on: ubuntu-latest - if: | - github.event.pull_request.draft == false && github.base_ref == 'main' - && github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' - && toJson(github.event.pull_request.requested_reviewers) == '[]' - steps: - - uses: kentaro-m/auto-assign-action@v1.2.1 - with: - configuration-path: '.github/reviewer_list.yml' diff --git a/.github/workflows/auto_example_check.yml b/.github/workflows/auto_example_check.yml new file mode 100644 index 000000000..7f1e357e3 --- /dev/null +++ b/.github/workflows/auto_example_check.yml @@ -0,0 +1,130 @@ +name: Test Example +on: + pull_request: + # any change in the examples folder will trigger check for the corresponding example. + paths: + - 'examples/**' + # run at 00:00 of every Sunday(singapore time) so here is UTC time Saturday 16:00 + schedule: + - cron: '0 16 * * 6' + +jobs: + # This is for changed example files detect and output a matrix containing all the corresponding directory name. + detect-changed-example: + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.setup-matrix.outputs.matrix }} + anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }} + name: Detect changed example files + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.sha }} + - name: Get all changed example files + id: changed-files + uses: tj-actions/changed-files@v35 + # Using this can trigger action each time a PR is submitted. + with: + since_last_remote_commit: true + - name: setup matrix + id: setup-matrix + run: | + changedFileName="" + for file in ${{ steps.changed-files.outputs.all_changed_files }}; do + changedFileName="${file}:${changedFileName}" + done + echo "$changedFileName was changed" + res=`python .github/workflows/scripts/example_checks/detect_changed_example.py --fileNameList $changedFileName` + echo "All changed examples are $res" + + if [ "$x" = "[]" ]; then + echo "anyChanged=false" >> $GITHUB_OUTPUT + echo "matrix=null" >> $GITHUB_OUTPUT + else + dirs=$( IFS=',' ; echo "${res[*]}" ) + echo "anyChanged=true" >> $GITHUB_OUTPUT + echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT + fi + + # If no file is changed, it will prompt an error and shows the matrix do not have value. + check-changed-example: + # Add this condition to avoid executing this job if the trigger event is workflow_dispatch. + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' + name: Test the changed example + needs: detect-changed-example + runs-on: [self-hosted, gpu] + strategy: + matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}} + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/examples-data:/data/ + timeout-minutes: 10 + steps: + - uses: actions/checkout@v3 + - name: Install Colossal-AI + run: | + pip install -v . + - name: Test the example + run: | + example_dir=${{ matrix.directory }} + cd "${PWD}/examples/${example_dir}" + bash test_ci.sh + env: + NCCL_SHM_DISABLE: 1 + + # This is for all files' weekly check. Specifically, this job is to find all the directories. + matrix_preparation: + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule' + name: Prepare matrix for weekly check + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.setup-matrix.outputs.matrix }} + steps: + - name: 📚 Checkout + uses: actions/checkout@v3 + - name: setup matrix + id: setup-matrix + run: | + res=`python .github/workflows/scripts/example_checks/check_example_weekly.py` + all_loc=$( IFS=',' ; echo "${res[*]}" ) + echo "Found the examples: $all_loc" + echo "matrix={\"directory\":$(echo "$all_loc")}" >> $GITHUB_OUTPUT + + weekly_check: + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'schedule' + name: Weekly check all examples + needs: matrix_preparation + runs-on: [self-hosted, gpu] + strategy: + matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + timeout-minutes: 10 + steps: + - name: 📚 Checkout + uses: actions/checkout@v3 + - name: Install Colossal-AI + run: | + pip install -v . + - name: Traverse all files + run: | + example_dir=${{ matrix.diretory }} + echo "Testing ${example_dir} now" + cd "${PWD}/examples/${example_dir}" + bash test_ci.sh + env: + NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b7023098f..62d6350d6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,17 +1,45 @@ name: Build -on: +on: pull_request: types: [synchronize, labeled] jobs: - build: - name: Build and Test Colossal-AI + detect: + name: Detect kernel-related file change if: | github.event.pull_request.draft == false && github.base_ref == 'main' && github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && contains( github.event.pull_request.labels.*.name, 'Run Build and Test') + outputs: + changedFiles: ${{ steps.find-changed-files.outputs.changedFiles }} + anyChanged: ${{ steps.find-changed-files.outputs.any_changed }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.sha }} + - name: Find the changed files + id: find-changed-files + uses: tj-actions/changed-files@v35 + with: + since_last_remote_commit: true + files: | + op_builder/** + colossalai/kernel/** + setup.py + - name: List changed files + run: | + for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do + echo "$file was changed" + done + + + build: + name: Build and Test Colossal-AI + needs: detect runs-on: [self-hosted, gpu] container: image: hpcaitech/pytorch-cuda:1.11.0-11.3.0 @@ -23,27 +51,38 @@ jobs: repository: hpcaitech/TensorNVMe ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} path: TensorNVMe + - name: Install tensornvme run: | cd TensorNVMe conda install cmake pip install -r requirements.txt pip install -v . + - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - - name: Install Colossal-AI + + - name: Restore cache + if: needs.detect.outputs.anyChanged != 'true' run: | - [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - pip install -r requirements/requirements.txt - pip install -v -e . - cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ - cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/ + # -p flag is required to preserve the file timestamp to avoid ninja rebuild + [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + + - name: Install Colossal-AI + run: | + CUDA_EXT=1 pip install -v -e . pip install -r requirements/requirements-test.txt + - name: Unit Testing run: | - PYTHONPATH=$PWD pytest tests + PYTHONPATH=$PWD pytest --cov=. --cov-report lcov tests env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + + - name: Store Cache + run: | + # -p flag is required to preserve the file timestamp to avoid ninja rebuild + cp -p -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ diff --git a/.github/workflows/build_gpu_8.yml b/.github/workflows/build_gpu_8.yml index 4d96390f2..be8337dd0 100644 --- a/.github/workflows/build_gpu_8.yml +++ b/.github/workflows/build_gpu_8.yml @@ -2,7 +2,7 @@ name: Build on 8 GPUs on: schedule: - # run at 00:00 of every Sunday + # run at 00:00 of every Sunday - cron: '0 0 * * *' workflow_dispatch: @@ -30,13 +30,11 @@ jobs: - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - - name: Install Colossal-AI + - name: Install Colossal-AI run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - pip install -r requirements/requirements.txt - pip install -v -e . + CUDA_EXT=1 pip install -v -e . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ - cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/ pip install -r requirements/requirements-test.txt - name: Unit Testing run: | @@ -45,4 +43,3 @@ jobs: env: DATA: /data/scratch/cifar-10 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 - \ No newline at end of file diff --git a/.github/workflows/compatibility_test.yml b/.github/workflows/compatibility_test.yml index 7948eb20c..eadd07886 100644 --- a/.github/workflows/compatibility_test.yml +++ b/.github/workflows/compatibility_test.yml @@ -70,7 +70,7 @@ jobs: - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} - - name: Install Colossal-AI + - name: Install Colossal-AI run: | pip install -r requirements/requirements.txt pip install -v --no-cache-dir . diff --git a/.github/workflows/dispatch_example_check.yml b/.github/workflows/dispatch_example_check.yml new file mode 100644 index 000000000..e0333422f --- /dev/null +++ b/.github/workflows/dispatch_example_check.yml @@ -0,0 +1,63 @@ +name: Manual Test Example +on: + workflow_dispatch: + inputs: + example_directory: + type: string + description: example directory, separated by space. For example, language/gpt, images/vit. Simply input language or simply gpt does not work. + required: true + +jobs: + matrix_preparation: + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + name: Check the examples user want + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: 📚 Checkout + uses: actions/checkout@v3 + - name: Set up matrix + id: set-matrix + env: + check_dir: ${{ inputs.example_directory }} + run: | + res=`python .github/workflows/scripts/example_checks/check_dispatch_inputs.py --fileNameList $check_dir` + if [ res == "failure" ];then + exit -1 + fi + dirs="[${check_dir}]" + echo "Testing examples in $dirs" + echo "matrix={\"directory\":$(echo "$dirs")}" >> $GITHUB_OUTPUT + + test_example: + if: | + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + name: Manually check example files + needs: manual_check_matrix_preparation + runs-on: [self-hosted, gpu] + strategy: + matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}} + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/examples-data:/data/ + timeout-minutes: 10 + steps: + - name: 📚 Checkout + uses: actions/checkout@v3 + - name: Install Colossal-AI + run: | + pip install -v . + - name: Test the example + run: | + dir=${{ matrix.directory }} + echo "Testing ${dir} now" + cd "${PWD}/examples/${dir}" + bash test_ci.sh + env: + NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/draft_github_release_post.yml b/.github/workflows/draft_github_release_post.yml index f970a9091..413714daf 100644 --- a/.github/workflows/draft_github_release_post.yml +++ b/.github/workflows/draft_github_release_post.yml @@ -20,7 +20,7 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v2 with: - python-version: '3.7.12' + python-version: '3.8.14' - name: generate draft id: generate_draft run: | @@ -42,4 +42,3 @@ jobs: body_path: ${{ steps.generate_draft.outputs.path }} draft: True prerelease: false - \ No newline at end of file diff --git a/.github/workflows/release_bdist.yml b/.github/workflows/release_bdist.yml index aeac3e327..c9c51df8d 100644 --- a/.github/workflows/release_bdist.yml +++ b/.github/workflows/release_bdist.yml @@ -64,9 +64,21 @@ jobs: - name: Copy scripts and checkout run: | cp -r ./.github/workflows/scripts/* ./ + + # link the cache diretories to current path + ln -s /github/home/conda_pkgs ./conda_pkgs ln -s /github/home/pip_wheels ./pip_wheels + + # set the conda package path + echo "pkgs_dirs:\n - $PWD/conda_pkgs" > ~/.condarc + + # set safe directory git config --global --add safe.directory /__w/ColossalAI/ColossalAI + + # check out git checkout $git_ref + + # get cub package for cuda 10.2 wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip unzip 1.8.0.zip env: diff --git a/.github/workflows/release_docker.yml b/.github/workflows/release_docker.yml index 8e88ea311..c72d3fb33 100644 --- a/.github/workflows/release_docker.yml +++ b/.github/workflows/release_docker.yml @@ -18,23 +18,17 @@ jobs: with: fetch-depth: 0 - name: Build Docker + id: build run: | version=$(cat version.txt) - docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t hpcaitech/colossalai:$version ./docker + tag=hpcaitech/colossalai:$version + docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t $tag ./docker + echo "tag=${tag}" >> $GITHUB_OUTPUT - name: Log in to Docker Hub uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - - name: Extract metadata (tags, labels) for Docker - id: meta - uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38 - with: - images: hpcaitech/colossalai - - name: Build and push Docker image - uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc - with: - context: . - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} \ No newline at end of file + - name: Push Docker image + run: | + docker push ${{ steps.build.outputs.tag }} diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index 0ef942841..8aa48b8ed 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -1,74 +1,29 @@ -name: Release bdist wheel for Nightly versions +name: Publish Nightly Version to PyPI on: - schedule: - # run at 00:00 of every Sunday - - cron: '0 0 * * 6' workflow_dispatch: - -jobs: - matrix_preparation: - name: Prepare Container List - runs-on: ubuntu-latest - outputs: - matrix: ${{ steps.set-matrix.outputs.matrix }} - steps: - - id: set-matrix - run: | - matrix="[\"hpcaitech/cuda-conda:11.3\", \"hpcaitech/cuda-conda:10.2\"]" - echo $matrix - echo "::set-output name=matrix::{\"container\":$(echo $matrix)}" + schedule: + - cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time - build: - name: Release bdist wheels - needs: matrix_preparation - if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor) - runs-on: [self-hosted, gpu] - strategy: - fail-fast: false - matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} - container: - image: ${{ matrix.container }} - options: --gpus all --rm - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - # cub is for cuda 10.2 - - name: Copy scripts and checkout - run: | - cp -r ./.github/workflows/scripts/* ./ - ln -s /github/home/pip_wheels ./pip_wheels - wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip - unzip 1.8.0.zip - - name: Build bdist wheel - run: | - pip install beautifulsoup4 requests packaging - python ./build_colossalai_wheel.py --nightly - - name: 🚀 Deploy - uses: garygrossgarten/github-action-scp@release - with: - local: all_dist - remote: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }} - host: ${{ secrets.PRIVATE_PYPI_HOST }} - username: ${{ secrets.PRIVATE_PYPI_USER }} - password: ${{ secrets.PRIVATE_PYPI_PASSWD }} - remove_old_build: - name: Remove old nightly build +jobs: + build-n-publish: + if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI' + name: Build and publish Python 🐍 distributions 📦 to PyPI runs-on: ubuntu-latest - needs: build + timeout-minutes: 20 steps: - - name: executing remote ssh commands using password - uses: appleboy/ssh-action@master - env: - BUILD_DIR: ${{ secrets.PRIVATE_PYPI_NIGHTLY_DIR }} - with: - host: ${{ secrets.PRIVATE_PYPI_HOST }} - username: ${{ secrets.PRIVATE_PYPI_USER }} - password: ${{ secrets.PRIVATE_PYPI_PASSWD }} - envs: BUILD_DIR - script: | - cd $BUILD_DIR - find . -type f -mtime +0 -exec rm -f {} + - script_stop: true - + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.8.14' + + - run: NIGHTLY=1 python setup.py sdist build + + # publish to PyPI if executed on the main branch + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} + verbose: true diff --git a/.github/workflows/release.yml b/.github/workflows/release_pypi.yml similarity index 63% rename from .github/workflows/release.yml rename to .github/workflows/release_pypi.yml index ab83c7a43..7f3f63cf3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release_pypi.yml @@ -1,21 +1,29 @@ name: Publish to PyPI -on: workflow_dispatch +on: + workflow_dispatch: + pull_request: + paths: + - 'version.txt' + types: + - closed jobs: build-n-publish: - if: github.ref_name == 'main' && github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor) + if: github.event_name == 'workflow_dispatch' || github.repository == 'hpcaitech/ColossalAI' && github.event.pull_request.merged == true && github.base_ref == 'main' name: Build and publish Python 🐍 distributions 📦 to PyPI runs-on: ubuntu-latest timeout-minutes: 20 steps: - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 with: - python-version: '3.7.12' + python-version: '3.8.14' + - run: python setup.py sdist build + # publish to PyPI if executed on the main branch - # publish to Test PyPI if executed on the develop branch - name: Publish package to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml deleted file mode 100644 index cae88edaa..000000000 --- a/.github/workflows/release_test.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: Publish to Test PyPI - -on: workflow_dispatch - -jobs: - build-n-publish: - if: github.repository == 'hpcaitech/ColossalAI' && contains(fromJson('["FrankLeeeee", "ver217", "feifeibear", "kurisusnowdeng"]'), github.actor) - name: Build and publish Python 🐍 distributions 📦 to Test PyPI - runs-on: ubuntu-latest - timeout-minutes: 20 - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: '3.7.12' - - run: python setup.py sdist build - # publish to PyPI if executed on the main branch - # publish to Test PyPI if executed on the develop branch - - name: Publish package to Test PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - user: __token__ - password: ${{ secrets.TEST_PYPI_API_TOKEN }} - repository_url: https://test.pypi.org/legacy/ - verbose: true diff --git a/.github/workflows/scripts/build_colossalai_wheel.py b/.github/workflows/scripts/build_colossalai_wheel.py index 5a2db0c87..a9ac16fbc 100644 --- a/.github/workflows/scripts/build_colossalai_wheel.py +++ b/.github/workflows/scripts/build_colossalai_wheel.py @@ -1,12 +1,13 @@ -from filecmp import cmp -import requests -from bs4 import BeautifulSoup import argparse import os import subprocess -from packaging import version +from filecmp import cmp from functools import cmp_to_key +import requests +from bs4 import BeautifulSoup +from packaging import version + WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels' RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels' CUDA_HOME = os.environ['CUDA_HOME'] diff --git a/.github/workflows/scripts/build_colossalai_wheel.sh b/.github/workflows/scripts/build_colossalai_wheel.sh index 55a87d956..c0d40fd2c 100644 --- a/.github/workflows/scripts/build_colossalai_wheel.sh +++ b/.github/workflows/scripts/build_colossalai_wheel.sh @@ -18,7 +18,7 @@ if [ $1 == "pip" ] then wget -nc -q -O ./pip_wheels/$filename $url pip install ./pip_wheels/$filename - + elif [ $1 == 'conda' ] then conda install pytorch==$torch_version cudatoolkit=$cuda_version $flags @@ -34,8 +34,9 @@ fi python setup.py bdist_wheel mv ./dist/* ./all_dist +# must remove build to enable compilation for +# cuda extension in the next build +rm -rf ./build python setup.py clean conda deactivate conda env remove -n $python_version - - diff --git a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py new file mode 100644 index 000000000..04d2063ec --- /dev/null +++ b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py @@ -0,0 +1,27 @@ +import argparse +import os + + +def check_inputs(input_list): + for path in input_list: + real_path = os.path.join('examples', path) + if not os.path.exists(real_path): + return False + return True + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-f', '--fileNameList', type=str, help="List of file names") + args = parser.parse_args() + name_list = args.fileNameList.split(",") + is_correct = check_inputs(name_list) + + if is_correct: + print('success') + else: + print('failure') + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/example_checks/check_example_weekly.py b/.github/workflows/scripts/example_checks/check_example_weekly.py new file mode 100644 index 000000000..941e90901 --- /dev/null +++ b/.github/workflows/scripts/example_checks/check_example_weekly.py @@ -0,0 +1,37 @@ +import os + + +def show_files(path, all_files): + # Traverse all the folder/file in current directory + file_list = os.listdir(path) + # Determine the element is folder or file. If file, pass it into list, if folder, recurse. + for file_name in file_list: + # Get the abs directory using os.path.join() and store into cur_path. + cur_path = os.path.join(path, file_name) + # Determine whether folder + if os.path.isdir(cur_path): + show_files(cur_path, all_files) + else: + all_files.append(cur_path) + return all_files + + +def join(input_list, sep=None): + return (sep or ' ').join(input_list) + + +def main(): + contents = show_files('examples/', []) + all_loc = [] + for file_loc in contents: + split_loc = file_loc.split('/') + # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not. + if len(split_loc) >= 4: + re_loc = '/'.join(split_loc[1:3]) + if re_loc not in all_loc: + all_loc.append(re_loc) + print(all_loc) + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/example_checks/detect_changed_example.py b/.github/workflows/scripts/example_checks/detect_changed_example.py new file mode 100644 index 000000000..df4fd6736 --- /dev/null +++ b/.github/workflows/scripts/example_checks/detect_changed_example.py @@ -0,0 +1,24 @@ +import argparse + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files") + args = parser.parse_args() + name_list = args.fileNameList.split(":") + folder_need_check = set() + for loc in name_list: + # Find only the sub-sub-folder of 'example' folder + # the examples folder structure is like + # - examples + # - area + # - application + # - file + if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4: + folder_need_check.add('/'.join(loc.split("/")[1:3])) + # Output the result using print. Then the shell can get the values. + print(list(folder_need_check)) + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/generate_release_draft.py b/.github/workflows/scripts/generate_release_draft.py index fdcd667ae..1c407cf14 100644 --- a/.github/workflows/scripts/generate_release_draft.py +++ b/.github/workflows/scripts/generate_release_draft.py @@ -2,9 +2,10 @@ # coding: utf-8 import argparse -import requests -import re import os +import re + +import requests COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits' TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags' diff --git a/.github/workflows/submodule.yml b/.github/workflows/submodule.yml index ac01f85db..4ffb26118 100644 --- a/.github/workflows/submodule.yml +++ b/.github/workflows/submodule.yml @@ -1,6 +1,6 @@ name: Synchronize Submodule -on: +on: workflow_dispatch: schedule: - cron: "0 0 * * *" @@ -27,11 +27,11 @@ jobs: - name: Commit update run: | - git config --global user.name 'github-actions' - git config --global user.email 'github-actions@github.com' + git config --global user.name 'github-actions' + git config --global user.email 'github-actions@github.com' git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }} git commit -am "Automated submodule synchronization" - + - name: Create Pull Request uses: peter-evans/create-pull-request@v3 with: @@ -43,4 +43,3 @@ jobs: assignees: ${{ github.actor }} delete-branch: true branch: create-pull-request/patch-sync-submodule - \ No newline at end of file diff --git a/.gitignore b/.gitignore index 458f37553..8e345eeb8 100644 --- a/.gitignore +++ b/.gitignore @@ -134,10 +134,23 @@ dmypy.json .vscode/ # macos -.DS_Store +*.DS_Store #data/ docs/.build # pytorch checkpoint -*.pt \ No newline at end of file +*.pt + +# ignore version.py generated by setup.py +colossalai/version.py + +# ignore any kernel build files +.o +.so + +# ignore python interface defition file +.pyi + +# ignore coverage test file +converage.lcov diff --git a/.readthedocs.yaml b/.readthedocs.yaml index ce22f43c1..98dd0cc4e 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -27,4 +27,4 @@ sphinx: python: install: - requirements: requirements/requirements.txt - - requirements: docs/requirements.txt \ No newline at end of file + - requirements: docs/requirements.txt diff --git a/LICENSE b/LICENSE index 9ca515ca7..0528c89ea 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2021- The Colossal-ai Authors. All rights reserved. +Copyright 2021- HPC-AI Technology Inc. All rights reserved. Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ @@ -187,7 +187,7 @@ Copyright 2021- The Colossal-ai Authors. All rights reserved. same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2021- HPC-AI Technology Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/MANIFEST.in b/MANIFEST.in index 48a44e0b4..ad26b634a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ include *.txt README.md recursive-include requirements *.txt -recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc \ No newline at end of file +recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi +recursive-include op_builder *.py diff --git a/README-zh-Hans.md b/README-zh-Hans.md index b678af55d..b97b02f5a 100644 --- a/README-zh-Hans.md +++ b/README-zh-Hans.md @@ -1,14 +1,14 @@ # Colossal-AI
- [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Colossal-AI_logo.png)](https://www.colossalai.org/) + [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/) Colossal-AI: 一个面向大模型时代的通用深度学习系统 -

论文 | - 文档 | - 例程 | - 论坛 | +

论文 | + 文档 | + 例程 | + 论坛 | 博客

[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml) @@ -22,41 +22,50 @@
+## 新闻 +* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0) +* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) +* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) +* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for) +* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the) + ## 目录 @@ -98,6 +102,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - 1维, [2维](https://arxiv.org/abs/2104.05343), [2.5维](https://arxiv.org/abs/2105.14500), [3维](https://arxiv.org/abs/2105.14450) 张量并行 - [序列并行](https://arxiv.org/abs/2105.13120) - [零冗余优化器 (ZeRO)](https://arxiv.org/abs/1910.02054) + - [自动并行](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt) - 异构内存管理 - [PatrickStar](https://arxiv.org/abs/2108.05818) - 使用友好 @@ -105,16 +110,11 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - 推理 - [Energon-AI](https://github.com/hpcaitech/EnergonAI) - Colossal-AI 成功案例 - - [xTrimoMultimer: 蛋白质单体与复合物结构预测](https://github.com/biomap-research/xTrimoMultimer) + - 生物医药: [FastFold](https://github.com/hpcaitech/FastFold) 加速蛋白质结构预测 AlphaFold 训练与推理

(返回顶端)

## 并行训练样例展示 -### ViT -

- -

-- 14倍批大小和5倍训练速度(张量并行=64) ### GPT-3

@@ -131,7 +131,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - 用相同的硬件训练24倍大的模型 -- 超3倍的吞吐量 +- 超3倍的吞吐量 ### BERT @@ -145,10 +145,16 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), 由Meta发布的1750亿语言模型,由于完全公开了预训练参数权重,因此促进了下游任务和应用部署的发展。 -- 加速45%,仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://service.colossalai.org/opt) +- 加速45%,仅用几行代码以低成本微调OPT。[[样例]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[在线推理]](https://service.colossalai.org/opt) 请访问我们的 [文档](https://www.colossalai.org/) 和 [例程](https://github.com/hpcaitech/ColossalAI-Examples) 以了解详情。 +### ViT +

+ +

+ +- 14倍批大小和5倍训练速度(张量并行=64) ### 推荐系统模型 - [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding), 使用软件Cache实现Embeddings,用更少GPU显存训练更大的模型。 @@ -178,7 +184,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - 用相同的硬件训练34倍大的模型 -

(back to top)

+

(返回顶端)

## 推理 (Energon-AI) 样例展示 @@ -195,23 +201,82 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - [OPT推理服务](https://service.colossalai.org/opt): 无需注册,免费体验1750亿参数OPT在线推理服务 +

+ +

-

(back to top)

+- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): 降低1750亿参数BLOOM模型部署推理成本超10倍 + +

(返回顶端)

## Colossal-AI 成功案例 -### xTrimoMultimer: 蛋白质单体与复合物结构预测 +### AIGC +加速AIGC(AI内容生成)模型,如[Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) 和 [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion) + +

+ +

+ +- [训练](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): 减少5.6倍显存消耗,硬件成本最高降低46倍(从A100到RTX3060) + +

+ +

+ +- [DreamBooth微调](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): 仅需3-5张目标主题图像个性化微调 + +

+ +

+ +- [推理](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): GPU推理显存消耗降低2.5倍 + + +

(返回顶端)

+ +### 生物医药 + +加速 [AlphaFold](https://alphafold.ebi.ac.uk/) 蛋白质结构预测 + +

+ +

+ +- [FastFold](https://github.com/hpcaitech/FastFold): 加速AlphaFold训练与推理、数据前处理、推理序列长度超过10000残基 +

- -

- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): 11倍加速蛋白质单体与复合物结构预测 +

(返回顶端)

## 安装 +### 从PyPI安装 + +您可以用下面的命令直接从PyPI上下载并安装Colossal-AI。我们默认不会安装PyTorch扩展包 + +```bash +pip install colossalai +``` + +但是,如果你想在安装时就直接构建PyTorch扩展,您可以设置环境变量`CUDA_EXT=1`. + +```bash +CUDA_EXT=1 pip install colossalai +``` + +**否则,PyTorch扩展只会在你实际需要使用他们时在运行时里被构建。** + +与此同时,我们也每周定时发布Nightly版本,这能让你提前体验到新的feature和bug fix。你可以通过以下命令安装Nightly版本。 + +```bash +pip install colossalai-nightly +``` + ### 从官方安装 您可以访问我们[下载](https://www.colossalai.org/download)页面来安装Colossal-AI,在这个页面上发布的版本都预编译了CUDA扩展。 @@ -231,10 +296,10 @@ pip install -r requirements/requirements.txt pip install . ``` -如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装): +我们默认在`pip install`时不安装PyTorch扩展,而是在运行时临时编译,如果你想要提前安装这些扩展的话(在使用融合优化器时会用到),可以使用一下命令。 ```shell -NO_CUDA_EXT=1 pip install . +CUDA_EXT=1 pip install . ```

(返回顶端)

@@ -283,31 +348,6 @@ docker run -ti --gpus all --rm --ipc=host colossalai bash

(返回顶端)

-## 快速预览 - -### 几行代码开启分布式训练 - -```python -parallel = dict( - pipeline=2, - tensor=dict(mode='2.5d', depth = 1, size=4) -) -``` - -### 几行代码开启异构训练 - -```python -zero = dict( - model_config=dict( - tensor_placement_policy='auto', - shard_strategy=TensorShardStrategy(), - reuse_fp16_shard=True - ), - optimizer_config=dict(initial_scale=2**5, gpu_margin_mem_ratio=0.2) -) -``` - -

(返回顶端)

## 引用我们 @@ -320,4 +360,4 @@ zero = dict( } ``` -

(返回顶端)

\ No newline at end of file +

(返回顶端)

diff --git a/README.md b/README.md index c5a798a0e..7aba907e0 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ # Colossal-AI
- [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/Colossal-AI_logo.png)](https://www.colossalai.org/) + [![logo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/colossal-ai_logo_vertical.png)](https://www.colossalai.org/) Colossal-AI: A Unified Deep Learning System for Big Model Era -

Paper | - Documentation | - Examples | - Forum | +

Paper | + Documentation | + Examples | + Forum | Blog

[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml) @@ -17,46 +17,55 @@ [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech) [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) [![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png) - + | [English](README.md) | [中文](README-zh-Hans.md) |
+## Latest News +* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0) +* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) +* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) +* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for) +* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the) + ## Table of Contents @@ -100,8 +104,9 @@ distributed training and inference in a few lines. - 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism - [Sequence Parallelism](https://arxiv.org/abs/2105.13120) - [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054) + - [Auto-Parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt) -- Heterogeneous Memory Management +- Heterogeneous Memory Management - [PatrickStar](https://arxiv.org/abs/2108.05818) - Friendly Usage @@ -110,17 +115,11 @@ distributed training and inference in a few lines. - Inference - [Energon-AI](https://github.com/hpcaitech/EnergonAI) -- Colossal-AI in the Real World - - [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): Accelerating Protein Monomer and Multimer Structure Prediction +- Colossal-AI in the Real World + - Biomedicine: [FastFold](https://github.com/hpcaitech/FastFold) accelerates training and inference of AlphaFold protein structure

(back to top)

## Parallel Training Demo -### ViT -

- -

- -- 14x larger batch size, and 5x faster training for Tensor Parallelism = 64 ### GPT-3

@@ -150,10 +149,17 @@ distributed training and inference in a few lines. - [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model released by Meta, which stimulates AI programmers to perform various downstream tasks and application deployments because public pretrained model weights. -- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt) +- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt) Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI-Examples) for more details. +### ViT +

+ +

+ +- 14x larger batch size, and 5x faster training for Tensor Parallelism = 64 + ### Recommendation System Models - [Cached Embedding](https://github.com/hpcaitech/CachedEmbedding), utilize software cache to train larger embedding tables with a smaller GPU memory budget. @@ -198,26 +204,85 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt - [OPT Serving](https://service.colossalai.org/opt): Try 175-billion-parameter OPT online services for free, without any registration whatsoever. +

+ +

+ +- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): Reduce hardware deployment costs of 175-billion-parameter BLOOM by more than 10 times. +

(back to top)

## Colossal-AI in the Real World -### xTrimoMultimer: Accelerating Protein Monomer and Multimer Structure Prediction +### AIGC +Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion). +

+ +

+ +- [Training](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce Stable Diffusion memory consumption by up to 5.6x and hardware cost by up to 46x (from A100 to RTX3060). + +

+ +

+ +- [DreamBooth Fine-tuning](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): Personalize your model using just 3-5 images of the desired subject. + +

+ +

+ +- [Inference](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce inference GPU memory consumption by 2.5x. + + +

(back to top)

+ +### Biomedicine +Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/) + +

+ +

+ +- [FastFold](https://github.com/hpcaitech/FastFold): accelerating training and inference on GPU Clusters, faster data processing, inference sequence containing more than 10000 residues. +

- -

-- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): accelerating structure prediction of protein monomers and multimer by 11x +- [xTrimoMultimer](https://github.com/biomap-research/xTrimoMultimer): accelerating structure prediction of protein monomers and multimer by 11x. +

(back to top)

## Installation +### Install from PyPI + +You can easily install Colossal-AI with the following command. **By defualt, we do not build PyTorch extensions during installation.** + +```bash +pip install colossalai +``` + +However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`. + +```bash +CUDA_EXT=1 pip install colossalai +``` + +**Otherwise, CUDA kernels will be built during runtime when you actually need it.** + +We also keep release the nightly version to PyPI on a weekly basis. This allows you to access the unreleased features and bug fixes in the main branch. +Installation can be made via + +```bash +pip install colossalai-nightly +``` + ### Download From Official Releases -You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built CUDA extensions. +You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built PyTorch extensions. ### Download From Source @@ -228,17 +293,15 @@ You can visit the [Download](https://www.colossalai.org/download) page to downlo git clone https://github.com/hpcaitech/ColossalAI.git cd ColossalAI -# install dependency -pip install -r requirements/requirements.txt - # install colossalai pip install . ``` -If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer): +By default, we do not compile CUDA/C++ kernels. ColossalAI will build them during runtime. +If you want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer): ```shell -NO_CUDA_EXT=1 pip install . +CUDA_EXT=1 pip install . ```

(back to top)

@@ -289,32 +352,6 @@ Thanks so much to all of our amazing contributors!

(back to top)

-## Quick View - -### Start Distributed Training in Lines - -```python -parallel = dict( - pipeline=2, - tensor=dict(mode='2.5d', depth = 1, size=4) -) -``` - -### Start Heterogeneous Training in Lines - -```python -zero = dict( - model_config=dict( - tensor_placement_policy='auto', - shard_strategy=TensorShardStrategy(), - reuse_fp16_shard=True - ), - optimizer_config=dict(initial_scale=2**5, gpu_margin_mem_ratio=0.2) -) - -``` - -

(back to top)

## Cite Us diff --git a/colossalai/_C/__init__.py b/colossalai/_C/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/__init__.py b/colossalai/__init__.py index fab03445b..f859161f7 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -7,4 +7,11 @@ from .initialize import ( launch_from_torch, ) -__version__ = '0.1.11rc1' +try: + # .version will be created by setup.py + from .version import __version__ +except ModuleNotFoundError: + # this will only happen if the user did not run `pip install` + # and directly set PYTHONPATH to use Colossal-AI which is a bad practice + __version__ = '0.0.0' + print('please install Colossal-AI from https://www.colossalai.org/download or from source') diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/amp/apex_amp/__init__.py index 6689a157c..51b9b97dc 100644 --- a/colossalai/amp/apex_amp/__init__.py +++ b/colossalai/amp/apex_amp/__init__.py @@ -1,7 +1,8 @@ -from .apex_amp import ApexAMPOptimizer import torch.nn as nn from torch.optim import Optimizer +from .apex_amp import ApexAMPOptimizer + def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config): r"""A helper function to wrap training components with Apex AMP modules diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py index bb2b8eb26..5b2f71d3c 100644 --- a/colossalai/amp/naive_amp/__init__.py +++ b/colossalai/amp/naive_amp/__init__.py @@ -1,10 +1,13 @@ import inspect + import torch.nn as nn from torch.optim import Optimizer + from colossalai.utils import is_no_pp_or_last_stage -from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel -from .grad_scaler import DynamicGradScaler, ConstantGradScaler + from ._fp16_optimizer import FP16Optimizer +from .grad_scaler import ConstantGradScaler, DynamicGradScaler +from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py index 58d9e3df1..e4699f92b 100644 --- a/colossalai/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/amp/naive_amp/_fp16_optimizer.py @@ -3,24 +3,33 @@ import torch import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.optim import Optimizer + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.logging import get_dist_logger +from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier + +from ._utils import has_inf_or_nan, zero_gard_by_list +from .grad_scaler import BaseGradScaler try: - import colossal_C + from colossalai._C import fused_optim except: - print('Colossalai should be built with cuda extension to use the FP16 optimizer') - -from torch.optim import Optimizer -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.logging import get_dist_logger -from colossalai.utils import (copy_tensor_parallel_attributes, clip_grad_norm_fp32, multi_tensor_applier) -from torch.distributed import ProcessGroup -from .grad_scaler import BaseGradScaler -from ._utils import has_inf_or_nan, zero_gard_by_list + fused_optim = None __all__ = ['FP16Optimizer'] +def load_fused_optim(): + global fused_optim + + if fused_optim is None: + fused_optim = FusedOptimBuilder().load() + + def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): """ adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM) @@ -33,7 +42,9 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): if overflow_buf: overflow_buf.fill_(0) # Scaling with factor `1.0` is equivalent to copy. - multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0) + global fused_optim + load_fused_optim() + multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0) else: for this_, that_ in zip(this, that): that_.copy_(this_) @@ -41,7 +52,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): class FP16Optimizer(Optimizer): """Float16 optimizer for fp16 and bf16 data types. - + Args: optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD grad_scaler (BaseGradScaler): grad scaler for gradient chose in @@ -73,8 +84,8 @@ class FP16Optimizer(Optimizer): # get process group def _get_process_group(parallel_mode): - if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA): - return gpc.get_group(ParallelMode.DATA) + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode): + return gpc.get_group(parallel_mode) else: return None @@ -150,6 +161,12 @@ class FP16Optimizer(Optimizer): f"==========================================", ranks=[0]) + @property + def max_norm(self): + """Returns the maximum norm of gradient clipping. + """ + return self._clip_grad_max_norm + @property def grad_scaler(self): """Returns the gradient scaler. diff --git a/colossalai/amp/naive_amp/_utils.py b/colossalai/amp/naive_amp/_utils.py index ad2a2ceed..7633705e1 100644 --- a/colossalai/amp/naive_amp/_utils.py +++ b/colossalai/amp/naive_amp/_utils.py @@ -1,4 +1,5 @@ from typing import List + from torch import Tensor diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index d27883a8e..0d84384a7 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -1,12 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import torch from abc import ABC, abstractmethod -from colossalai.logging import get_dist_logger -from torch import Tensor from typing import Dict +import torch +from torch import Tensor + +from colossalai.logging import get_dist_logger + __all__ = ['BaseGradScaler'] diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index 1ac26ee91..6d6f2f287 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -1,10 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import torch -from .base_grad_scaler import BaseGradScaler from typing import Optional +import torch + +from .base_grad_scaler import BaseGradScaler + __all__ = ['DynamicGradScaler'] diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py index 02eae80b9..6a39d518d 100644 --- a/colossalai/amp/naive_amp/naive_amp.py +++ b/colossalai/amp/naive_amp/naive_amp.py @@ -1,17 +1,20 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import torch -import torch.nn as nn -import torch.distributed as dist -from torch import Tensor from typing import Any -from torch.optim import Optimizer -from torch.distributed import ReduceOp -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.nn.optimizer import ColossalaiOptimizer + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.distributed import ReduceOp +from torch.optim import Optimizer + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.optimizer import ColossalaiOptimizer + from ._fp16_optimizer import FP16Optimizer @@ -40,7 +43,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer): return self.optim.step() def clip_grad_norm(self, model: nn.Module, max_norm: float): - pass + if self.optim.max_norm == max_norm: + return + raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). " + "If you have supplied clip_grad_norm in the amp_config, " + "executing the method clip_grad_norm is not allowed.") class NaiveAMPModel(nn.Module): diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/amp/torch_amp/__init__.py index 8943b86d6..893cc890d 100644 --- a/colossalai/amp/torch_amp/__init__.py +++ b/colossalai/amp/torch_amp/__init__.py @@ -1,10 +1,13 @@ -import torch.nn as nn -from torch.optim import Optimizer -from torch.nn.modules.loss import _Loss -from colossalai.context import Config -from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss from typing import Optional +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer + +from colossalai.context import Config + +from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer + def convert_to_torch_amp(model: nn.Module, optimizer: Optimizer, diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/amp/torch_amp/_grad_scaler.py index de39b3e16..7b78998fb 100644 --- a/colossalai/amp/torch_amp/_grad_scaler.py +++ b/colossalai/amp/torch_amp/_grad_scaler.py @@ -3,16 +3,18 @@ # modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py # to support tensor parallel -import torch -from collections import defaultdict, abc import warnings +from collections import abc, defaultdict from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from colossalai.context import ParallelMode + +import torch import torch.distributed as dist -from colossalai.core import global_context as gpc -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from packaging import version +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc class _MultiDeviceReplicator(object): diff --git a/colossalai/amp/torch_amp/torch_amp.py b/colossalai/amp/torch_amp/torch_amp.py index 5074e9c81..65718d77c 100644 --- a/colossalai/amp/torch_amp/torch_amp.py +++ b/colossalai/amp/torch_amp/torch_amp.py @@ -1,17 +1,17 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import torch.nn as nn import torch.cuda.amp as torch_amp - +import torch.nn as nn from torch import Tensor from torch.nn.modules.loss import _Loss from torch.optim import Optimizer -from ._grad_scaler import GradScaler from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.utils import clip_grad_norm_fp32 +from ._grad_scaler import GradScaler + class TorchAMPOptimizer(ColossalaiOptimizer): """A wrapper class which integrate Pytorch AMP with an optimizer diff --git a/colossalai/auto_parallel/checkpoint/__init__.py b/colossalai/auto_parallel/checkpoint/__init__.py index e69de29bb..10ade417a 100644 --- a/colossalai/auto_parallel/checkpoint/__init__.py +++ b/colossalai/auto_parallel/checkpoint/__init__.py @@ -0,0 +1,3 @@ +from .ckpt_solver_base import CheckpointSolverBase +from .ckpt_solver_chen import CheckpointSolverChen +from .ckpt_solver_rotor import CheckpointSolverRotor diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py new file mode 100644 index 000000000..af4349865 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py @@ -0,0 +1,16 @@ +import os + +from setuptools import Extension, setup + +this_dir = os.path.dirname(os.path.abspath(__file__)) +ext_modules = [Extension( + 'rotorc', + sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')], +)] + +setup( + name='rotor c extension', + version='0.1', + description='rotor c extension for faster dp computing', + ext_modules=ext_modules, +) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py new file mode 100644 index 000000000..b388d00ac --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -0,0 +1,195 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, List + +import torch +from torch.fx import Graph, Node + +from colossalai.auto_parallel.passes.runtime_apply_pass import ( + runtime_apply, + runtime_apply_for_iterable_object, + runtime_comm_spec_apply, +) +from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen + +__all___ = ['CheckpointSolverBase'] + + +def _copy_output(src: Graph, dst: Graph): + """Copy the output node from src to dst""" + for n_src, n_dst in zip(src.nodes, dst.nodes): + if n_src.op == 'output': + n_dst.meta = n_src.meta + + +def _get_param_size(module: torch.nn.Module): + """Get the size of the parameters in the module""" + return sum([p.numel() * torch.tensor([], dtype=p.dtype).element_size() for p in module.parameters()]) + + +class CheckpointSolverBase(ABC): + + def __init__( + self, + graph: Graph, + free_memory: float = -1.0, + requires_linearize: bool = False, + cnode: List[str] = None, + optim_multiplier: float = 1.0, + ): + """``CheckpointSolverBase`` class will integrate information provided by the components + and use an existing solver to find a possible optimal strategies combination for target + computing graph. + + Existing Solvers: + Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen) + Rotor solver: https://hal.inria.fr/hal-02352969 (CheckpointSolverRotor) + + Args: + graph (Graph): The computing graph to be optimized. + free_memory (float): Memory constraint for the solution. + requires_linearize (bool): Whether the graph needs to be linearized. + cnode (List[str], optional): Common node List, should be the subset of input. Default to None. + optim_multiplier (float, optional): The multiplier of extra weight storage for the + ``torch.optim.Optimizer``. Default to 1.0. + + Warnings: + Meta information of the graph is required for any ``CheckpointSolver``. + """ + # super-dainiu: this graph is a temporary graph which can refer to + # the owning module, but we will return another deepcopy of it after + # the solver is executed. + self.graph = deepcopy(graph) + self.graph.owning_module = graph.owning_module + _copy_output(graph, self.graph) + self.graph.set_codegen(ActivationCheckpointCodeGen()) + + # check if has meta information + if any(len(node.meta) == 0 for node in self.graph.nodes): + raise RuntimeError( + "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!" + ) + + # parameter memory = parameter size + optimizer extra weight storage + self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1) + self.cnode = cnode + self.requires_linearize = requires_linearize + if self.requires_linearize: + self.node_list = self._linearize_graph() + else: + self.node_list = self.get_node_list() + + @abstractmethod + def solve(self): + """Solve the checkpointing problem and return the solution. + """ + pass + + def get_node_list(self): + """Get the node list. + """ + return [[node] for node in self.graph.nodes] + + def _linearize_graph(self) -> List[List[Node]]: + """Linearizing the graph + + Args: + graph (Graph): The computing graph to be optimized. + + Returns: + List[List[Node]]: List of list, each inside list of Node presents + the actual 'node' in linearized manner. + + Remarks: + Do merge the inplace ops and shape-consistency ops into the previous node. + """ + + # Common nodes are type of nodes that could be seen as attributes and remain + # unchanged throughout the whole model, it will be used several times by + # different blocks of model, so that it is hard for us to linearize the graph + # when we encounter those kinds of nodes. We let users to annotate some of the + # input as common node, such as attention mask, and the followings are some of + # the ops that could actually be seen as common nodes. With our common node prop, + # we could find some of the "real" common nodes (e.g. the real attention mask + # used in BERT and GPT), the rule is simple, for node who's parents are all common + # nodes or it's op belongs to the following operations, we view this node as a + # newly born common node. + # List of target name that could be seen as common node + common_ops = ["getattr", "getitem", "size"] + + def _is_cop(target: Any) -> bool: + """Check if an op could be seen as common node + + Args: + target (Any): node target + + Returns: + bool + """ + + if isinstance(target, str): + return target in common_ops + else: + return target.__name__ in common_ops + + def _is_sink() -> bool: + """Check if we can free all dependencies + + Returns: + bool + """ + + def _is_inplace(n: Node): + """Get the inplace argument from ``torch.fx.Node`` + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + return inplace + + def _is_shape_consistency(n: Node): + """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``) + """ + return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] + + return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( + map(_is_shape_consistency, n.users)) + + # make sure that item in cnode is valid + if self.cnode: + for name in self.cnode: + try: + assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ + f"Common node {name} is not an input of the model." + except StopIteration: + raise ValueError(f"Common node name {name} not in graph.") + + else: + self.cnode = [] + + deps = {} + node_list = [] + region = [] + + for n in self.graph.nodes: + if n.op != "placeholder" and n.op != "output": + for n_par in n.all_input_nodes: + if n_par.op != "placeholder" and n_par.name not in self.cnode: + deps[n_par] -= 1 + region.append(n) + + # if the node could free all dependencies in graph + # we could begin a new node + if _is_sink(): + node_list.append(region) + region = [] + + # propagate common node attr if possible + if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode + ]) or _is_cop(n.target): + self.cnode.append(n.name) + else: + deps[n] = len([user for user in n.users if user.op != "output"]) + return node_list diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py new file mode 100644 index 000000000..19b2ef598 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py @@ -0,0 +1,87 @@ +import math +from copy import deepcopy +from typing import List, Set, Tuple + +from torch.fx import Graph, Node + +from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp + +from .ckpt_solver_base import CheckpointSolverBase + +__all__ = ['CheckpointSolverChen'] + + +class CheckpointSolverChen(CheckpointSolverBase): + + def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6): + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + Note that this algorithm targets at memory optimization only, using techniques in appendix A. + + Usage: + Assume that we have a ``GraphModule``, and we have already done the extractions + to the graph to retrieve all information needed, then we could use the following + code to find a solution using ``CheckpointSolverChen``: + >>> solver = CheckpointSolverChen(gm.graph) + >>> chen_graph = solver.solve() + >>> gm.graph = chen_graph # set the graph to a new graph + + Args: + graph (Graph): The computing graph to be optimized. + cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. + num_grids (int, optional): Number of grids to search for b. Defaults to 6. + """ + super().__init__(graph, 0, 0, True, cnode) + self.num_grids = num_grids + + def solve(self) -> Graph: + """Solve the checkpointing problem using Algorithm 3. + + Returns: + graph (Graph): The optimized graph, should be a copy of the original graph. + """ + checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr'] + ckpt = self.grid_search() + for i, seg in enumerate(ckpt): + for idx in range(*seg): + nodes = self.node_list[idx] + for n in nodes: + if n.op in checkpointable_op: + n.meta['activation_checkpoint'] = i + return deepcopy(self.graph) + + def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]: + """ + This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. + """ + ckpt_intv = [] + temp = 0 + x = 0 + y = 0 + prev_idx = 2 + for idx, nodes in enumerate(self.node_list): + for n in nodes: + n: Node + temp += calculate_fwd_in(n) + calculate_fwd_tmp(n) + y = max(y, temp) + if temp > b and idx > prev_idx: + x += calculate_fwd_in(nodes[0]) + temp = 0 + ckpt_intv.append((prev_idx, idx + 1)) + prev_idx = idx + 1 + return ckpt_intv, math.floor(math.sqrt(x * y)) + + def grid_search(self) -> Set: + """ + Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy. + Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A. + """ + _, b_approx = self.run_chen_greedy(0) + b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) + b_opt = math.inf + for b in range(b_min, b_max, (b_max - b_min) // self.num_grids): + ckpt_intv, b_approx = self.run_chen_greedy(b) + if b_approx < b_opt: + b_opt = b_approx + ckpt_opt = ckpt_intv + return ckpt_opt diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c new file mode 100644 index 000000000..0fdcfd58a --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c @@ -0,0 +1,197 @@ +#define PY_SSIZE_T_CLEAN +#include + +long* PySequenceToLongArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + long* result = (long*)calloc(len + 1, sizeof(long)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyLong_AsLong(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +double* PySequenceToDoubleArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + double* result = (double*)calloc(len + 1, sizeof(double)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyFloat_AsDouble(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +long* getLongArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + long* result = PySequenceToLongArray(sequence); + Py_DECREF(sequence); + return result; +} + +double* getDoubleArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + double* result = PySequenceToDoubleArray(sequence); + Py_DECREF(sequence); + return result; +} + +static PyObject* computeTable(PyObject* self, PyObject* args) { + PyObject* chainParam; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL; + + double* ftime = getDoubleArray(chainParam, "ftime"); + if (!ftime) return NULL; + + double* btime = getDoubleArray(chainParam, "btime"); + if (!btime) return NULL; + + long* x = getLongArray(chainParam, "x"); + if (!x) return NULL; + + long* xbar = getLongArray(chainParam, "xbar"); + if (!xbar) return NULL; + + long* ftmp = getLongArray(chainParam, "btmp"); + if (!ftmp) return NULL; + + long* btmp = getLongArray(chainParam, "btmp"); + if (!btmp) return NULL; + + long chainLength = PyObject_Length(chainParam); + if (!chainLength) return NULL; + +#define COST_TABLE(m, i, l) \ + costTable[(m) * (chainLength + 1) * (chainLength + 1) + \ + (i) * (chainLength + 1) + (l)] + double* costTable = (double*)calloc( + (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(double)); + +#define BACK_PTR(m, i, l) \ + backPtr[(m) * (chainLength + 1) * (chainLength + 1) + \ + (i) * (chainLength + 1) + (l)] + long* backPtr = (long*)calloc( + (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long)); + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chainLength; ++i) + if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) && + (m >= x[i + 1] + xbar[i + 1] + ftmp[i])) + COST_TABLE(m, i, i) = ftime[i] + btime[i]; + else + COST_TABLE(m, i, i) = INFINITY; + + for (long m = 0; m <= mmax; ++m) + for (long d = 1; d <= chainLength; ++d) { + for (long i = 0; i <= chainLength - d; ++i) { + long idx = i + d; + long mmin = x[idx + 1] + x[i + 1] + ftmp[i]; + if (idx > i + 1) { + long maxCostFWD = 0; + for (long j = i + 1; j < idx; j++) { + maxCostFWD = fmaxl(maxCostFWD, x[j] + x[j + 1] + ftmp[j]); + } + mmin = fmaxl(mmin, x[idx + 1] + maxCostFWD); + } + if ((m >= mmin)) { + long bestLeaf = -1; + double sumFw = 0; + double bestLeafCost = INFINITY; + for (long j = i + 1; j <= idx; ++j) { + sumFw += ftime[j - 1]; + if (m >= x[j]) { + double cost = sumFw + COST_TABLE(m - x[j], j, idx) + + COST_TABLE(m, i, j - 1); + if (cost < bestLeafCost) { + bestLeafCost = cost; + bestLeaf = j; + } + } + } + double chainCost = INFINITY; + if (m >= xbar[i + 1]) + chainCost = + COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx); + if (bestLeafCost <= chainCost) { + COST_TABLE(m, i, idx) = bestLeafCost; + BACK_PTR(m, i, idx) = bestLeaf; + } else { + COST_TABLE(m, i, idx) = chainCost; + BACK_PTR(m, i, idx) = -1; + } + } else + COST_TABLE(m, i, idx) = INFINITY; + } + } + + free(ftime); + free(btime); + free(x); + free(xbar); + free(ftmp); + free(btmp); + + PyObject* pyCostTable = PyList_New(mmax + 1); + PyObject* pyBackPtr = PyList_New(mmax + 1); + + // Convert the result into Python world + for (long m = 0; m <= mmax; ++m) { + PyObject* pyCostTable_m = PyList_New(chainLength + 1); + PyList_SET_ITEM(pyCostTable, m, pyCostTable_m); + PyObject* pyBackPtr_m = PyList_New(chainLength + 1); + PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m); + for (long i = 0; i <= chainLength; ++i) { + PyObject* pyCostTable_m_i = PyDict_New(); + PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i); + PyObject* pyBackPtr_m_i = PyDict_New(); + PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i); + for (long l = i; l <= chainLength; ++l) { + PyObject* pyVar_l = PyLong_FromLong(l); + PyObject* pyCostTable_m_i_l = PyFloat_FromDouble(COST_TABLE(m, i, l)); + PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l); + Py_DECREF(pyCostTable_m_i_l); + PyObject* pyBackPtr_m_i_l; + if (BACK_PTR(m, i, l) < 0) + pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True); + else + pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l)); + PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l); + Py_DECREF(pyBackPtr_m_i_l); + Py_DECREF(pyVar_l); + } + } + } + + free(costTable); + free(backPtr); + + PyObject* result = PyTuple_Pack(2, pyCostTable, pyBackPtr); + Py_DECREF(pyCostTable); + Py_DECREF(pyBackPtr); + return result; +} + +static PyMethodDef rotorMethods[] = { + {"compute_table", computeTable, METH_VARARGS, + "Compute the optimal table with the rotor algorithm."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +static struct PyModuleDef rotorModule = { + PyModuleDef_HEAD_INIT, "rotorc", /* name of module */ + "A simple implementation of dynamic programming algorithm rotor with C in " + "https://hal.inria.fr/hal-02352969. Some code are adapted from " + "https://gitlab.inria.fr/hiepacs/rotor.", /* module documentation, may be + NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + rotorMethods}; + +PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); } diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py new file mode 100644 index 000000000..41d23be5c --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -0,0 +1,441 @@ +from copy import deepcopy +from typing import Any, Dict, List, Tuple + +from torch import Tensor +from torch.fx import Graph, Node + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply +from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions +from colossalai.fx.profiler import ( + activation_size, + calculate_bwd_time, + calculate_fwd_out, + calculate_fwd_time, + calculate_fwd_tmp, +) +from colossalai.logging import get_dist_logger + +from .ckpt_solver_base import CheckpointSolverBase +from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence + +__all__ = ['CheckpointSolverRotor'] + + +class CheckpointSolverRotor(CheckpointSolverBase): + + def __init__(self, + graph: Graph, + free_memory: float = -1, + cnode: List[str] = None, + memory_slots: int = 500, + optim_multiplier: float = 1.0): + """This is the simple implementation of dynamic programming algorithm rotor + in https://hal.inria.fr/hal-02352969. Some code are adapted from + https://gitlab.inria.fr/hiepacs/rotor. + + Usage: + Assume that we have a ``GraphModule``, and we have already done the extractions + to the graph to retrieve all information needed, then we could use the following + code to find a solution using ``CheckpointSolverRotor``: + >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0]) + >>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver + >>> gm.graph = rotor_graph # set the graph to a new graph + + Args: + graph (Graph): The computing graph to be optimized. + free_memory (float, optional): Memory constraint for the solution, unit is byte. + Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1. + cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. + memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500. + optim_multiplier (float, optional): The multiplier of extra weight storage for the + ``torch.optim.Optimizer``. Default to 1.0. + """ + super().__init__(graph, free_memory, True, cnode, optim_multiplier) + self.memory_slots = memory_slots + + # construct chain + unit = self.free_memory // self.memory_slots + self.chain = self._construct_chain(self.graph, self.node_list) + self.chain.discretize_all(unit) + + self.cost_table = None + self.back_ptr = None + self.sequence = None + + def solve(self, force_python: bool = False, verbose: bool = False) -> Graph: + """Solve the checkpointing problem using rotor algorithm. + + Args: + force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False. + verbose (bool, optional): Print verbose information. Defaults to False. + + Returns: + graph (Graph): The optimized graph, should be a copy of the original graph. + """ + chain = self.chain + + # compute cost table + if force_python: + self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots) + else: + self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots) + + if verbose: + self.print_chain() + + # backtrack + try: + self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, + self.back_ptr) + self._annotate_from_sequence(self.sequence, self.node_list) + except ValueError as e: + # using logger to annonce that the solver is failed + logger = get_dist_logger() + logger.warning(f'Checkpoint solver failed: {e}') + raise ValueError + + if verbose: + self.print_sequence() + + return deepcopy(self.graph) + + def print_chain(self): + print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) + for idx in range(len(self.node_list) - 1): + print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx], + self.chain.btmp[idx]) + print(f'Chain = {self.chain}') + + def print_sequence(self): + print(f'Sequence = {self.sequence}') + + @classmethod + def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain: + input_tensors = cls._extract_input(graph) + ftime, btime, ftmp, btmp = list(), list(), list(), list() + xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)] + + for node in node_list: + node_info = cls._extract_node_info(node) + ftime.append(node_info[0]) + btime.append(node_info[1]) + x.append(node_info[2]) + xbar.append(node_info[3]) + ftmp.append(node_info[4]) + btmp.append(node_info[5]) + + # currently we view loss backward temp as zero + btime.append(0) + btmp.append(0) + + return Chain(ftime, btime, x, xbar, ftmp, btmp) + + @classmethod + def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]: + """Extract node info from a list of nodes""" + xbar = 0 + ftime = 0 + btime = 0 + fwd_mem_peak = 0 + for n in node: + assert isinstance(n, Node), f'{n} is not a Node' + if n.target == runtime_apply or n.target == runtime_comm_spec_apply: + # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta + xbar += n.meta['fwd_mem_out'] + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp']) + else: + xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) + + # minimum flop count is required + ftime += max(calculate_fwd_time(n), 1.0) + btime += max(calculate_bwd_time(n), 1.0) + + x = calculate_fwd_out(node[-1]) + xbar = max(x, xbar) + ftmp = fwd_mem_peak - xbar + btmp = cls._extract_btmp(node) + return ftime, btime, x, xbar, ftmp, btmp + + @staticmethod + def _extract_input(graph: Graph) -> Tuple[Tensor, ...]: + """Extract input tensors from a Graph""" + input_tensors = [] + for node in graph.nodes: + if node.op == 'placeholder': + input_tensors.append(node.meta['fwd_out']) + return input_tensors + + @staticmethod + def _extract_unused_output(node: Node) -> int: + """Extract unused output from `torch.fx.Node`""" + return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node) + + @staticmethod + def _extract_btmp(node: List[Node]) -> int: + """Extract btmp from a list of nodes""" + + def _extract_deps_size(): + deps_size = 0 + for k, v in deps.items(): + k: Node + if v > 0: + deps_size += k.meta['bwd_mem_out'] + if v == float('-inf'): + deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k) + + return deps_size + + btmp = 0 + deps = {} + for n in reversed(node): + deps[n] = len(n.all_input_nodes) + btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp']) + for child in n.users: + if child in deps: + deps[child] -= 1 + if deps[child] <= 0: + deps[child] = float('-inf') # free + return btmp + + @staticmethod + def _compute_table(chain: Chain, mmax: int) -> Tuple: + """Compute the table using dynamic programming. Returns the cost table and the backtracking pointer. + + Args: + chain (Chain): A basic linearized structure for solving the dynamic programming problem. + mmax (int): Maximum number of memory slots. + + Returns: + cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length + and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax + back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice + is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint + of length j + """ + + ftime = chain.ftime + [0.0] + btime = chain.btime + x = chain.x + [0] + xbar = chain.xbar + [0] + ftmp = chain.ftmp + [0] + btmp = chain.btmp + [0] + + # Build table + cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] + back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] + # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation + + # Initialize borders of the tables for lmax-lmin = 0 + for m in range(mmax + 1): + for i in range(len(chain) + 1): + limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i]) + if m >= limit: # Equation (1) + cost_table[m][i][i] = ftime[i] + btime[i] + else: + cost_table[m][i][i] = float("inf") + + # Compute everything + for m in range(mmax + 1): + for d in range(1, len(chain) + 1): + for i in range(len(chain) + 1 - d): + idx = i + d + mmin = x[idx + 1] + x[i + 1] + ftmp[i] + if idx > i + 1: + mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx))) + if m < mmin: + cost_table[m][i][idx] = float("inf") + else: + leaf_checkpoints = [(j, + sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) + for j in range(i + 1, idx + 1) + if m >= x[j]] + if leaf_checkpoints: + best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) + else: + best_leaf = None + if m >= xbar[i + 1]: + chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx] + else: + chain_checkpoint = float("inf") + if best_leaf and best_leaf[1] <= chain_checkpoint: + cost_table[m][i][idx] = best_leaf[1] + back_ptr[m][i][idx] = (False, best_leaf[0]) + else: + cost_table[m][i][idx] = chain_checkpoint + back_ptr[m][i][idx] = (True,) + return cost_table, back_ptr + + @staticmethod + def _compute_table_c(chain: Chain, mmax: int) -> Tuple: + try: + from .rotorc import compute_table + + # build module if module not found + except ModuleNotFoundError: + import os + import subprocess + import sys + logger = get_dist_logger() + logger.info("rotorc hasn't been built! Building library...", ranks=[0]) + this_dir = os.path.dirname(os.path.abspath(__file__)) + result = subprocess.Popen( + [ + f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", + f"--build-lib={this_dir}" + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if result.wait() == 0: + logger.info("rotorc has been built!", ranks=[0]) + from .rotorc import compute_table + else: + logger.warning("rotorc built failed! Using python version!", ranks=[0]) + return CheckpointSolverRotor._compute_table(chain, mmax) + return compute_table(chain, mmax) + + @staticmethod + def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], + back_ptr: List[Any]) -> "Sequence": + """Backtrack the cost table and retrieve the optimal checkpointing strategy. + + Args: + chain (Chain): A basic linearized structure for solving the dynamic programming problem. + lhs (int): The left index of the interval to backtrack. + rhs (int): The right index of the interval to backtrack. + budget (int): The memory budget for processing this interval. + cost_table (List[Any]): See ``._compute_table()`` for definitions + back_ptr (List[Any]): See ``._compute_table()`` for definitions + + Raises: + ValueError: Can not process the chain. + + Returns: + sequence (Sequence): The sequence of executing nodes with checkpoints. + """ + if budget <= 0: + raise ValueError(f"Can not process a chain with negative memory {budget}") + elif cost_table[budget][lhs][rhs] == float("inf"): + raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}") + + sequence = Sequence() + if rhs == lhs: + if lhs == len(chain): + sequence += [Loss()] + else: + sequence += [ForwardEnable(lhs), Backward(lhs)] + return sequence + + if back_ptr[budget][lhs][rhs][0]: + sequence += [ + ForwardEnable(lhs), + CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, + back_ptr), + Backward(lhs), + ] + else: + best_leaf = back_ptr[budget][lhs][rhs][1] + sequence += [ForwardCheck(lhs)] + sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)] + sequence += [ + CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, + back_ptr), + CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr), + ] + return sequence + + @staticmethod + def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): + """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence. + + Args: + sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations. + node_list (List[List[Node]]): The list of nodes to annotate. + """ + op_list = sequence.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + fwd_list = op_list[:op_list.index(loss_op)] + bwd_list = op_list[op_list.index(loss_op) + 1:] + ckpt_idx = 0 + in_ckpt = False + ckpt_region = [] + + # forward annotation + for idx, op in enumerate(fwd_list, 0): + if in_ckpt: + if isinstance(op, ForwardNograd): + ckpt_region.append(idx) + + elif isinstance(op, ForwardEnable): + in_ckpt = False + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'] = [ckpt_idx] + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'] = [ckpt_idx] + + ckpt_idx += 1 + ckpt_region = [idx] + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + ckpt_region.append(idx) + + # annotate the backward if there is any nested activation checkpoint + in_recompute = False + for op in bwd_list: + if in_recompute: + if isinstance(op, ForwardNograd): + ckpt_region.append(op.index) + + elif isinstance(op, ForwardEnable): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'].append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'].append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [op.index] + + elif isinstance(op, Backward): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.meta['activation_checkpoint'].append(ckpt_idx) + + in_recompute = False + + else: + if not isinstance(op, Backward): + in_recompute = True + ckpt_idx = 0 + ckpt_region = [] + if isinstance(op, ForwardCheck): + ckpt_region.append(op.index) + + # postprocess, make sure every activation checkpoint label in the + # same activation checkpoint region (level = 0) has the same length + op_list = [] + for node in node_list: + op_list += node + ckpt_regions = _find_nested_ckpt_regions(op_list) + for (start_idx, end_idx) in ckpt_regions: + nested_length = max( + len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1)) + for idx in range(start_idx, end_idx + 1): + op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length - + len(op_list[idx].meta['activation_checkpoint'])) diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py new file mode 100644 index 000000000..ab0c6c5ad --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/operation.py @@ -0,0 +1,184 @@ +import math +from abc import ABC +from typing import Any, Iterable, List + +from torch.utils._pytree import tree_map + + +class Chain: + + def __init__(self, + ftime: List[float], + btime: List[float], + x: List[int], + xbar: List[int], + ftmp: List[int], + btmp: List[int], + check_consistency: bool = True): + """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint. + See paper https://hal.inria.fr/hal-02352969 for details. + + Args: + ftime (List[float]): The forward time of each node. + btime (List[float]): The backward time of each node. + x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper. + xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper. + ftmp (List[int]): The temporary forward memory of each node. + btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget. + check_consistency (bool, optional): Check the lengths consistency for the `Chain`. Defaults to True. + """ + self.ftime = ftime + self.btime = btime + self.x = x + self.xbar = xbar + self.ftmp = ftmp + self.btmp = btmp + if check_consistency and not self.check_lengths(): + raise AttributeError("In Chain, input lists do not have consistent lengths") + + def check_lengths(self): + return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1) + and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1) + and (len(self.xbar) == len(self) + 1)) + + def __repr__(self): + chain_list = [] + for i in range(len(self)): + chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i])) + i = len(self) + chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i])) + return chain_list.__repr__() + + def __len__(self): + return len(self.ftime) + + def discretize_all(self, unit: int): + """Discretize the chain into a list of chains according to unit size.""" + discretizer = lambda val: math.ceil(val / unit) + self.x = tree_map(discretizer, self.x) + self.xbar = tree_map(discretizer, self.xbar) + self.ftmp = tree_map(discretizer, self.ftmp) + self.btmp = tree_map(discretizer, self.btmp) + + +class Operation(ABC): + name = "Op" + + def __repr__(self) -> str: + return f"{self.name}_{self.index}" + + def shift(self, value): + if type(self.index) is tuple: + self.index = tuple(x + value for x in self.index) + else: + self.index += value + + +class Forward(Operation): + name = "F" + + def __init__(self, index): + self.index = index + + def cost(self, chain: Chain): + if chain is not None: + return chain.ftime[self.index] + else: + return 1 + + +class ForwardEnable(Forward): + name = "Fe" + + +class ForwardNograd(Forward): + name = "Fn" + + +class ForwardCheck(Forward): + name = "CF" + + +class Forwards(Operation): + + def __init__(self, start, end): + self.index = (start, end) + + def __repr__(self): + return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) + + def cost(self, chain: Chain): + if chain is not None: + return sum(chain.ftime[self.index[0]:self.index[1] + 1]) + else: + return (self.index[1] - self.index[0] + 1) + + +def isForward(op): + return type(op) is Forward or type(op) is Forwards + + +class Backward(Operation): + name = "B" + + def __init__(self, index): + self.index = index + + def cost(self, chain: Chain): + if chain is not None: + return chain.btime[self.index] + else: + return 1 + + +class Loss(Operation): + + def __init__(self): + pass + + def __repr__(self): + return "L" + + def cost(self, chain): + return 0 + + +class MemoryAccess(Operation): + name = "MA" + + def __init__(self, index): + self.index = index + + def cost(self, chain: Chain): + return 0 + + +class WriteMemory(MemoryAccess): + name = "WM" + + +class ReadMemory(MemoryAccess): + name = "RM" + + +class DiscardMemory(MemoryAccess): + name = "DM" + + +class Sequence(list): + + def __init__(self): + super().__init__() + + def __repr__(self): + return repr(self.list_operations()) + + def list_operations(self): + op_list = [] + for x in self: + if isinstance(x, Operation): + op_list.append(x) + else: + assert isinstance(x, Sequence) + op_list += x.list_operations() + return op_list diff --git a/colossalai/auto_parallel/meta_profiler/__init__.py b/colossalai/auto_parallel/meta_profiler/__init__.py new file mode 100644 index 000000000..bfd361951 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/__init__.py @@ -0,0 +1,3 @@ +from .meta_registry import * +from .metainfo import * +from .registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py new file mode 100644 index 000000000..35b8c13ee --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/constants.py @@ -0,0 +1,15 @@ +import operator + +import torch +import torch.nn as nn + +from ..tensor_shard.constants import * + +# list of inplace module +INPLACE_MODULE = [nn.ReLU] + +# list of inplace operations +INPLACE_OPS = [torch.flatten] + +# list of operations that do not save forward activations +NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub] diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py new file mode 100644 index 000000000..aa5f77f65 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py @@ -0,0 +1,6 @@ +from .activation import * +from .binary_elementwise_ops import * +from .conv import * +from .linear import * +from .norm import * +from .pooling import * diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py new file mode 100644 index 000000000..774457f7d --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -0,0 +1,74 @@ +from typing import List, Tuple + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping + +from ..registry import meta_register + +__all__ = ["relu_meta_info"] + + +@meta_register.register(torch.nn.ReLU) +def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.nn.ReLU metainfo generator + The aten graph of torch.nn.ReLU is + graph(): + %input_2 : [#users=1] = placeholder[target=placeholder](default=) + %relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {}) + %threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = args[0].data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + is_inplace = kwargs.get("inplace", False) + + # construct input args for forward + fwd_in_args = [input_tensor] + + # construct input args for backward + bwd_in_args = [output_tensor] + + # calculate cost + # the fwd op with compute cost is relu.default + # the bwd op with compute cost is threshold_backward + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,)) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: the inplace ReLU don't have forward memory cost + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost( + activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]), + parameter=0, + temp=0, + buffer=0) + + bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + # NOTE: It might seems a little bit weird here, we just want to align it with the older version + # of MetaInfoProp. In the future we might modify this part to make it clearer. + fwd_in = [] + fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py new file mode 100644 index 000000000..281a92c0d --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -0,0 +1,66 @@ +from typing import List, Tuple + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping + +from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION +from ..registry import meta_register + +__all__ = ['binary_elementwise_meta_info'] + + +@meta_register.register(BCAST_FUNC_OP) +def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Meta information generator for binary elementwise operations + NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they + don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`, + they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify + this behavior, it is critical for better memory estimation. + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT] + output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args)) + + # construct forward args for flop mapping + fwd_in_args = [opdata.data for opdata in input_op_data] + fwd_out_args = [output_op_data.data] + + # calculate cost + + # calculate compute cost + # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case + fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args) + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM]) + fwd_mem_cost = MemoryCost( + activation=activation_size(output_op_data.data), + parameter=param_mem_cost, + ) + bwd_mem_cost = MemoryCost( + activation=activation_size(fwd_in_args), + parameter=param_mem_cost, + ) + + # total cost + total_mem_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + ) + + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py new file mode 100644 index 000000000..d1bb6e7fa --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -0,0 +1,137 @@ +from typing import Callable, Dict, List, Tuple, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..registry import meta_register + +__all__ = ['convnd_meta_info'] + + +@meta_register.register(torch.nn.Conv1d) +@meta_register.register(torch.nn.Conv2d) +@meta_register.register(torch.nn.Conv3d) +@meta_register.register(torch.nn.functional.conv1d) +@meta_register.register(torch.nn.functional.conv2d) +@meta_register.register(torch.nn.functional.conv3d) +def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator + The atens graph of torch.nn.Convnd with bias is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %convolution_backward_default : [#users=3] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None, [None, None, None]), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + %detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {}) + + The atens graph of torch.nn.Convnd without bias is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None], [None, None], [None, None], None, [None, None], None), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %convolution_backward_default : [#users=2] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None], [None, None], [None, None], None, [None, None], None, [None, None, None]), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + has_bias: bool = False + input_tensor = args[0].data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + if len(args) == 4: + weight_tensors = [args[1].data, args[3].data] + else: + weight_tensors = [args[1].data] + + # check if conv has bias + if len(weight_tensors) > 1: + has_bias = True + # bias tensor's shape only has one dimension + if len(weight_tensors[0].shape) == 1: + bias_tensor, weight_tensor = weight_tensors + else: + weight_tensor, bias_tensor = weight_tensors + + else: + weight_tensor = weight_tensors[0] + + # construct input args for forward + fwd_args = [None] * 9 + + # weight and input + fwd_args[0] = input_tensor + fwd_args[1] = weight_tensor + fwd_args[2] = bias_tensor if has_bias else None + + # transpose indicator should be set to False + fwd_args[6] = False + + # construct input args for backward + bwd_args = [None] * 11 + + # weight and input + bwd_args[0] = output_tensor + bwd_args[1] = input_tensor + bwd_args[2] = weight_tensor + bwd_args[-1] = [True, True, True] if has_bias else [True, True, False] + + # calculate cost + # the fwd op with compute cost is convolution.default + # the bwd op with compute cost is convolution_backward.default + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \ + flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # TODO: use profiler to check conv temp memory + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost( + activation=activation_size([input_tensor, output_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor), + temp=0, + buffer=0) + + bwd_memory_cost = MemoryCost( + activation=activation_size([input_tensor, weight_tensor, bias_tensor]) + if has_bias else activation_size([input_tensor, weight_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor), + temp=0, + buffer=0) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py new file mode 100644 index 000000000..61f8fdff3 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py @@ -0,0 +1,172 @@ +from typing import Callable, Dict, List, Tuple, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..registry import meta_register + +__all__ = ['linear_meta_info'] + + +@meta_register.register(torch.nn.functional.linear) +@meta_register.register(torch.nn.Linear) +def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """torch.nn.Linear & torch.nn.functional.linear meta info generator + NOTE: currently we separate the bias part from the biased linear ops, we will consider the memory consumption in add metainfo generator, + but we will hold the bias mechanism in the linear metainfo generator for future use. + + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {}) + %zeros_like_default : [#users=3] = call_function[target=torch.ops.aten.zeros_like.default](args = (%addmm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {}) + %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {}) + %mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {}) + %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {}) + %sum_dim_int_list : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%zeros_like_default, [None], None), kwargs = {}) + %view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_dim_int_list, [None]), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%view_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {}) + %detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {}) + %detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {}) + + The one without bias is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%input_2, None), kwargs = {}) + %zeros_like_default : [#users=2] = call_function[target=torch.ops.aten.zeros_like.default](args = (%mm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {}) + %mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {}) + %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {}) + %mm_default_2 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default_2,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {}) + %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs + """ + + has_bias: bool = False + + input_tensor = args[0].data + output_tensor = args[2].data + if len(args) == 4: + weight_tensors = [args[1].data, args[3].data] + else: + weight_tensors = [args[1].data] + + # process the dimension of input and output + if len(input_tensor.shape) > 2: + input_tensor: torch.Tensor + input_tensor = input_tensor.view(-1, input_tensor.shape[-1]) + + if len(output_tensor.shape) > 2: + output_tensor: torch.Tensor + output_tensor = output_tensor.view(-1, output_tensor.shape[-1]) + + if len(weight_tensors) > 1: + has_bias = True + if len(weight_tensors[0].shape) == 2: + weight_tensor, bias_tensor = weight_tensors + else: + bias_tensor, weight_tensor = weight_tensors + else: + weight_tensor = weight_tensors[0] + + if has_bias: + # calculate cost with bias + # the fwd op with compute cost is addmm + # the bwd op with compute cost is mm * 2 and sum.dim_IntList + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default]( + [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ + flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \ + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: Linear don't have buffer and temp in forward and backward phase + # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]), + temp=0, + buffer=0) + + # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0 + bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]), + temp=0, + buffer=0) + + # total cost is to sum the forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + else: + # calculate cost without bias + # the fwd op with compute cost is mm + # the bwd op with compute cost is mm * 2 + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ + flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: Linear don't have buffer and temp in forward and backward phase + # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), + parameter=activation_size(weight_tensor), + temp=0, + buffer=0) + + # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0 + bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]), + parameter=activation_size(weight_tensor), + temp=0, + buffer=0) + + # total cost is to sum the forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py new file mode 100644 index 000000000..9b34332db --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py @@ -0,0 +1,103 @@ +from typing import Callable, Dict, List, Tuple, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping +from colossalai.tensor.sharding_spec import ShardingSpec + +from ..registry import meta_register + +__all__ = ['batchnormnd_meta_info'] + + +@meta_register.register(torch.nn.BatchNorm1d) +@meta_register.register(torch.nn.BatchNorm2d) +@meta_register.register(torch.nn.BatchNorm3d) +def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """BatchNorm1d, BatchNorm2d, BatchNorm3d, meta info generator + The aten graph of BatchNorm2d is like + + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %cudnn_batch_norm_default : [#users=4] = call_function[target=torch.ops.aten.cudnn_batch_norm.default](args = (%input_2, None, None, None, None, None, None, None), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%cudnn_batch_norm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {}) + %detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {}) + %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_default,), kwargs = {}) + %cudnn_batch_norm_backward_default : [#users=3] = call_function[target=torch.ops.aten.cudnn_batch_norm_backward.default](args = (%detach_default, %zeros_like_default, None, None, None, %detach_default_1, %detach_default_2, None, %detach_default_3), kwargs = {}) + %detach_default_4 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {}) + %detach_default_5 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_4,), kwargs = {}) + %detach_default_6 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {}) + %detach_default_7 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_6,), kwargs = {}) + %detach_default_8 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%cudnn_batch_norm_backward_default,), kwargs = {}) + %detach_default_9 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_8,), kwargs = {}) + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = args[0].data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + weight_tensor = next(filter(lambda x: x.name == "weight", args)).data + bias_tensor = next(filter(lambda x: x.name == "bias", args)).data + mean_tensor = next(filter(lambda x: x.name == "running_mean", args)).data + var_tensor = next(filter(lambda x: x.name == "running_var", args)).data + num_batch = next(filter(lambda x: x.name == "num_batches_tracked", args)).data + + # construct fwd args + # the fwd inputs are input, weight, bias, running_mean, running_var and some other args + # indicating the status of the module + # the fwd outputs are output, saved mean, saved inv std and num batches tracked + fwd_in_args = [input_tensor, weight_tensor, bias_tensor, mean_tensor, var_tensor, True, 0.1, 1e-5] + fwd_out_args = [output_tensor, mean_tensor, var_tensor, num_batch] + + # construct bwd args + # the bwd inputs are upstream grad, input, weight, running_mean, running_var, saved mean, + # saved inv std and some other args indicating the status of the module + # the bwd outputs are input grad, weight grad and bias grad + bwd_in_args = [ + output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch + ] + bwd_out_args = [input_tensor, weight_tensor, bias_tensor] + + # calculate cost + fwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # the fwd activation cost is output plus saved mean and saved inv std + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]), + temp=0, + buffer=activation_size([mean_tensor, var_tensor])) + + # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean + # and saved inv std during backward phase + bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]), + parameter=activation_size([weight_tensor, bias_tensor]), + temp=activation_size([mean_tensor, var_tensor]), + buffer=activation_size([mean_tensor, var_tensor])) + + # total cost is the sum of forward and backward cost + total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py new file mode 100644 index 000000000..79780c92e --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -0,0 +1,134 @@ +from typing import List, Tuple + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.fx.profiler.memory_utils import activation_size +from colossalai.fx.profiler.opcount import flop_mapping + +from ..registry import meta_register + +__all__ = ["avgpool_meta_info", "maxpool_meta_info"] + + +@meta_register.register(torch.nn.AdaptiveAvgPool1d) +@meta_register.register(torch.nn.AdaptiveAvgPool2d) +@meta_register.register(torch.nn.AdaptiveAvgPool3d) +@meta_register.register(torch.flatten) +def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Meta info for AdaptiveAvgPool + The aten graph of AdaptiveAvgPool is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %_adaptive_avg_pool2d_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d.default](args = (%input_2, [None, None]), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%_adaptive_avg_pool2d_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %_adaptive_avg_pool2d_backward_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d_backward.default](args = (%zeros_like_default, %detach_default), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_adaptive_avg_pool2d_backward_default,), kwargs = {}) + %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = args[0].data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + is_inplace = kwargs.get("inplace", False) + + # construct forward args for flop mapping + fwd_in_args = [input_tensor] + fwd_out_args = [output_tensor] + + # construct backward args for flop mapping + bwd_in_args = [output_tensor] + bwd_out_args = [input_tensor] + + # calculate cost + # the fwd op with compute cost is _adaptive_avg_pool2d.default + # the bwd op with compute cost is _adaptive_avg_pool2d_backward.default + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor)) + bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor)) + + # total cost + total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation) + + mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out + + +@meta_register.register(torch.nn.MaxPool1d) +@meta_register.register(torch.nn.MaxPool2d) +@meta_register.register(torch.nn.MaxPool3d) +def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: + """Meta info for MaxPool + The aten graph of MaxPool is + graph(): + %input_2 : [#users=2] = placeholder[target=placeholder](default=) + %max_pool2d_with_indices_default : [#users=2] = call_function[target=torch.ops.aten.max_pool2d_with_indices.default](args = (%input_2, [None, None], [None, None]), kwargs = {}) + %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%max_pool2d_with_indices_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None}) + %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {}) + %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_default,), kwargs = {}) + %max_pool2d_with_indices_backward_default : [#users=1] = call_function[target=torch.ops.aten.max_pool2d_with_indices_backward.default](args = (%zeros_like_default, %detach_default, [None, None], [None, None], [None, None], [None, None], None, %detach_default_1), kwargs = {}) + %detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_backward_default,), kwargs = {}) + %detach_default_3 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_2,), kwargs = {}) + + Returns: + Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs + """ + + input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data + output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + + # construct forward args for flop mapping + fwd_in_args = [input_tensor] + fwd_out_args = [output_tensor] + + # construct backward args for flop mapping + bwd_in_args = [output_tensor] + bwd_out_args = [input_tensor] + + # construct index matrix + index_matrix = torch.zeros_like(output_tensor, device="meta", dtype=torch.int64) + + # calculate cost + # the fwd op with compute cost is max_pool2d_with_indices.default + # the bwd op with compute cost is max_pool2d_with_indices_backward.default + + # calculate compute cost + fwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices.default](fwd_in_args, fwd_out_args) + bwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices_backward.default](bwd_in_args, bwd_out_args) + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) + + # calculate memory cost + # NOTE: the index matrix will be discarded in backward phase + # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward + fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix])) + + # temp memory for backward is the index matrix to be discarded + bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix), + temp=activation_size(index_matrix)) + + # total cost + total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp) + + mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(index_matrix, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] + + return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/metainfo.py new file mode 100644 index 000000000..218187768 --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/metainfo.py @@ -0,0 +1,117 @@ +from typing import Callable, List + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.tensor.sharding_spec import ShardingSpec + +from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION +from .registry import meta_register + +__all__ = ['MetaInfo'] + + +class MetaInfo: + """MetaInfo class + This class is used to store meta info based on sharding strategy and the given + target function. + """ + + def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None: + # compute cost of forward and backward computation + self.compute_cost: TrainCycleItem + + # compute memory cost of forward and backward phase + self.memory_cost: TrainCycleItem + + # list of input tensors + self.fwd_in: List[torch.Tensor] + + # list of buffer tensors + self.fwd_buffer: List[torch.Tensor] + + # list of output tensors + self.fwd_out: List[torch.Tensor] + + # sharding strategy + self._strategy = strategy + + # target function + self._target = target + + # compute metainfo if possible + if self._strategy is not None and self._target is not None: + self.compute_metainfo() + + @property + def strategy(self) -> ShardingStrategy: + return self._strategy + + @property + def target(self) -> Callable: + return self._target + + @strategy.setter + def strategy(self, strategy: ShardingStrategy) -> None: + self._strategy = strategy + if self._strategy is not None and self._target is not None: + self.compute_metainfo() + + @target.setter + def target(self, target: Callable) -> None: + self._target = target + if self._strategy is not None and self._target is not None: + self.compute_metainfo() + + def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor: + """ + Compute sharded opdata based on the given data and sharding spec. + """ + return OperationData(name=operation_data.name, + data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), + type=operation_data.type, + logical_shape=operation_data.logical_shape) + + def compute_metainfo(self): + """ + Compute meta info based on sharding strategy and the given target function. + """ + assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \ + f"Meta info for {self._target} is not registered." + if meta_register.has(self._target.__class__): + # module + meta_func = meta_register.get(self._target.__class__) + + # check whether the target in the list that we don't need to save activation + save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION + else: + # function + meta_func = meta_register.get(self._target) + + # check whether the target in the list that we don't need to save activation + save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION + + # construct args for meta_func + args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()] + + # construct kwargs + if self.target in INPLACE_MODULE: + kwargs = {'inplace': self.target.inplace} + elif self.target in INPLACE_OPS: + kwargs = {'inplace': True} + else: + kwargs = {'inplace': False} + + # compute metainfo with meta_func + self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs) + + # process corner case for NO_SAVE_ACTIVATION + if not save_fwd_in: + self.fwd_in = [] diff --git a/colossalai/auto_parallel/meta_profiler/registry.py b/colossalai/auto_parallel/meta_profiler/registry.py new file mode 100644 index 000000000..46350c4dd --- /dev/null +++ b/colossalai/auto_parallel/meta_profiler/registry.py @@ -0,0 +1,32 @@ +__all__ = ['Registry'] + + +class Registry: + + def __init__(self, name): + self.name = name + self.store = {} + + def register(self, source): + + def wrapper(func): + if isinstance(source, (list, tuple)): + # support register a list of items for this func + for element in source: + self.store[element] = func + else: + self.store[source] = func + return func + + return wrapper + + def get(self, source): + assert source in self.store, f'{source} not found in the {self.name} registry' + target = self.store[source] + return target + + def has(self, source): + return source in self.store + + +meta_register = Registry('meta') diff --git a/colossalai/auto_parallel/passes/__init__.py b/colossalai/auto_parallel/passes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py new file mode 100644 index 000000000..ab3acb056 --- /dev/null +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -0,0 +1,113 @@ +from typing import Dict + +import torch +from torch.fx import GraphModule +from torch.fx.node import Node + +from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem +from colossalai.tensor.comm_spec import CommSpec +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, + target_sharding_spec: ShardingSpec) -> MetaInfo: + # get comm_action_sequence and total_cost from shape_consistency_manager + _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( + origin_sharding_spec, target_sharding_spec) + + meta_info = MetaInfo() + # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel + # get mem cost for MetaInfo + mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) + # extract user that has _meta_data and extract element length + input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) + element_length = input_node._meta_data.element_size() + + mem_cost.fwd.activation *= element_length + mem_cost.fwd.temp *= element_length + mem_cost.bwd.activation *= element_length + mem_cost.bwd.temp *= element_length + mem_cost.total.activation *= element_length + + meta_info.memory_cost = mem_cost + + # get computation cost for MetaInfo + meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, + total_cost['backward'] * element_length, + total_cost['total'] * element_length) + + # get tensor shape for MetaInfo + origin_sharding_spec: ShardingSpec + target_sharding_spec: ShardingSpec + input_shape = origin_sharding_spec.get_sharded_shape_per_device() + output_shape = target_sharding_spec.get_sharded_shape_per_device() + + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + + return meta_info + + +def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo: + """ + This method is used to construct `MetaInto` for shape consistency node + """ + + # extract node index and user node index + args = node.args + node_index, user_node_index = args[3], args[4] + origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ + user_node_index] + + return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + + +def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo: + # extract node_index and op_data_name + node_index, op_data_name = node.args[2], node.args[3] + + comm_action = comm_actions_dict[node_index][op_data_name] + if isinstance(comm_action.comm_spec, CommSpec): + # this case is for all_reduce, there will be no memory cost + meta_info = MetaInfo() + meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost) + output_node = next(n for n in node.users if hasattr(n, '_meta_data')) + element_length = output_node._meta_data.element_size() + + total_cost = comm_action.comm_spec.get_comm_cost() + meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, + total_cost['backward'] * element_length, + total_cost['total'] * element_length) + + input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device() + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + else: + # this case will be handled by shape consistency manager + origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ + 'tgt_spec'] + meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + + return meta_info + + +def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, + comm_actions_dict: Dict) -> GraphModule: + """ + The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. + """ + for node in gm.graph.nodes: + if node.target == runtime_apply: + setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) + elif node.target == runtime_comm_spec_apply: + setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) + else: + pass + return gm diff --git a/colossalai/auto_parallel/passes/constants.py b/colossalai/auto_parallel/passes/constants.py new file mode 100644 index 000000000..b86088474 --- /dev/null +++ b/colossalai/auto_parallel/passes/constants.py @@ -0,0 +1,8 @@ +import torch + +OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten] + +OUTPUT_SAVED_MOD = [ + torch.nn.ReLU, + torch.nn.Softmax, +] diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py new file mode 100644 index 000000000..f7e07ef1e --- /dev/null +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -0,0 +1,165 @@ +import uuid +from dataclasses import asdict +from typing import List + +import torch +import torch.fx +from torch.fx import GraphModule +from torch.fx.node import Node + +from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS +from colossalai.fx._compatibility import compatibility +from colossalai.fx.profiler import GraphInfo + + +def _normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +@compatibility(is_backward_compatible=False) +class MetaInfoProp: + + def __init__(self, module: GraphModule) -> None: + self.module = module + self.func_dict = { + 'placeholder': self.placeholder_handler, + 'get_attr': self.get_attr_handler, + 'output': self.output_handler, + 'call_function': self.node_handler, + 'call_module': self.node_handler, + 'call_method': self.node_handler, + } + + def _set_data_ptr(self, x): + """ + Set uuid to tensor + """ + if isinstance(x, torch.Tensor): + if not x.data_ptr(): + data_ptr = uuid.uuid4() + x.data_ptr = lambda: data_ptr + + def _is_inplace(self, node: Node): + """ + Check if the node is inplace operation. + """ + if node.op == 'call_module': + return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD + elif node.op == "call_function": + return node.target in OUTPUT_SAVED_OPS + return False + + def run(self) -> GraphModule: + """ + Run the meta information propagation pass on the module. + """ + for node in self.module.graph.nodes: + node: Node + self.func_dict[node.op](node) + + @compatibility(is_backward_compatible=False) + def placeholder_handler(self, node: Node) -> None: + """ + Handle the placeholder node. + """ + graph_info = GraphInfo() + out = _normalize_tuple(getattr(node, '_meta_data', None)) + graph_info.fwd_out = list(out) if out[0] is not None else [] + node.meta = {**asdict(graph_info)} + + @compatibility(is_backward_compatible=False) + def get_attr_handler(self, node: Node) -> None: + """ + Handle the get_attr node. + """ + graph_info = GraphInfo() + node.meta = {**asdict(graph_info)} + + @compatibility(is_backward_compatible=False) + def output_handler(self, node: Node) -> None: + """ + Handle the output node. + """ + graph_info = GraphInfo() + output_tensors = [] + for par in node._input_nodes: + if par.meta: + output_tensors += par.meta["fwd_out"] + graph_info.fwd_in = output_tensors + node.meta = {**asdict(graph_info)} + + @compatibility(is_backward_compatible=False) + def node_handler(self, node: Node) -> None: + """ + Handle other kind of nodes + """ + assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}" + graph_info = GraphInfo() + meta_info = node.best_metainfo + meta_info: MetaInfo + + # set data_ptr for input_tensor in MetaInfo class + input_tensors: List[torch.Tensor] = meta_info.fwd_in + buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer + output_tensors: List[torch.Tensor] = meta_info.fwd_out + + if self._is_inplace(node): + # inplace operation will not create new tensor, and it only has one parent node + # TODO: Verify this observation + # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node + parent_node = list(node._input_nodes.keys())[0] + parent_tensor = parent_node.meta.get("fwd_out")[0] + parent_tensor: torch.Tensor + for tensor in input_tensors: + tensor.data_ptr = parent_tensor.data_ptr + for tensor in buffer_tensors: + tensor.data_ptr = parent_tensor.data_ptr + for tensor in output_tensors: + tensor.data_ptr = parent_tensor.data_ptr + + else: + for par in node._input_nodes: + # set data_ptr for the input_tensor of current node from the output_tensor of its parent node + for tensor in par.meta.get("fwd_out", []): + tensor: torch.Tensor + target_input_tensor = next( + (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None) + if target_input_tensor is not None: + target_input_tensor.data_ptr = tensor.data_ptr + + # set data_ptr for tensor in input_tensor that is not set + for tensor in input_tensors: + if not tensor.data_ptr(): + self._set_data_ptr(tensor) + + # set data_ptr for buffer_tensor + for tensor in buffer_tensors: + self._set_data_ptr(tensor) + + # set data_ptr for output_tensor + for tensor in output_tensors: + self._set_data_ptr(tensor) + + # attach them to graph_info + graph_info.fwd_in = input_tensors + graph_info.fwd_tmp = buffer_tensors + graph_info.fwd_out = output_tensors + + # fetch other memory informations + memory_cost = meta_info.memory_cost + graph_info.fwd_mem_tmp = memory_cost.fwd.temp + graph_info.fwd_mem_out = memory_cost.fwd.activation + graph_info.bwd_mem_tmp = memory_cost.bwd.temp + graph_info.bwd_mem_out = memory_cost.bwd.activation + + # fetch flop information + # here we use fwd_time and bwd_time to deal with the case that + # communication cost is a float + compute_cost = meta_info.compute_cost + graph_info.fwd_time = compute_cost.fwd + graph_info.bwd_time = compute_cost.bwd + + node.meta = {**asdict(graph_info)} diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py new file mode 100644 index 000000000..7f2aac42b --- /dev/null +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -0,0 +1,221 @@ +from copy import deepcopy +from typing import Dict, List + +import torch +from torch.fx.node import Node + +from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.comm_spec import CommSpec +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int): + """ + This method will be invoked during runtime to do the shape consistency, which make sure the activations is converted into + the user node expected form. + """ + origin_sharding_spec = origin_dict[node_index] + target_sharding_spec = input_dict[node_index][user_node_index] + return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec) + + +def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, + user_node_index: int): + """ + This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list + is converted into the user node expected form. + """ + rst = [] + for index, (origin_sharding_spec, + target_sharding_spec) in enumerate(zip(origin_dict[node_index], + input_dict[node_index][user_node_index])): + rst.append( + shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec, + target_sharding_spec)) + rst = type(node)(rst) + return rst + + +def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str): + """ + This method will be invoked during runtime to apply the comm action following the instruction of comm spec. + """ + comm_action = comm_actions_dict[node_index][op_data_name] + if isinstance(comm_action.comm_spec, CommSpec): + rst = comm_action.comm_spec.covert_spec_to_action(tensor) + else: + origin_sharding_spec = comm_action.comm_spec['src_spec'] + tgt_sharding_spec = comm_action.comm_spec['tgt_spec'] + rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec) + return rst + + +def _preprocess_graph(nodes: List[Node]): + """ + This method is used to extract all the placeholders with sharding information, + and mapping the nodes into the index of the origin graph. + """ + # mapping the node into the origin graph index + node_to_index_dict = {} + index = 0 + for node in nodes: + if node.target == 'sharding_spec_convert_dict': + input_dict_node = node + continue + if node.target == 'origin_node_sharding_spec_dict': + origin_dict_node = node + continue + if node.target == 'comm_actions_dict': + comm_actions_dict_node = node + continue + if not hasattr(node, 'best_strategy'): + continue + node_to_index_dict[node] = index + index += 1 + + return input_dict_node, origin_dict_node, comm_actions_dict_node, node_to_index_dict + + +def _shape_consistency_apply(gm: torch.fx.GraphModule): + """ + This pass is used to add the shape consistency node to the origin graph. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes) + + for node in nodes: + if not hasattr(node, 'best_strategy') or node.op == 'output': + continue + + for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes): + if isinstance(node.sharding_spec, (list, tuple)): + assert isinstance( + node.target_sharding_specs, + (list, + tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list' + total_difference = 0 + for sharding_spec, target_sharding_spec in zip(node.sharding_spec, + node.target_sharding_specs[user_node_index]): + total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec) + if total_difference == 0: + continue + with mod_graph.inserting_before(user_node): + shape_consistency_node = mod_graph.create_node('call_function', + runtime_apply_for_iterable_object, + args=(node, origin_dict_node, input_dict_node, + node_to_index_dict[node], user_node_index)) + + else: + assert isinstance(node.sharding_spec, + ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.' + if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0: + continue + with mod_graph.inserting_before(user_node): + shape_consistency_node = mod_graph.create_node('call_function', + runtime_apply, + args=(node, origin_dict_node, input_dict_node, + node_to_index_dict[node], user_node_index)) + + new_args = list(user_node.args) + new_kwargs = dict(user_node.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with shape_consistency_node + origin_index_args = new_args.index(node) + new_args[origin_index_args] = shape_consistency_node + user_node.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with shape_consistency_node + new_kwargs[str(node)] = shape_consistency_node + user_node.kwargs = new_kwargs + + return gm + + +def _comm_spec_apply(gm: torch.fx.GraphModule): + """ + This pass is used to add the comm spec apply node to the origin graph. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes) + + for node in nodes: + if not hasattr(node, 'best_strategy') or node.op == 'output': + continue + + comm_actions = node.best_strategy.communication_actions + for op_data, comm_action in comm_actions.items(): + + if comm_action.comm_type == CommType.HOOK: + continue + if comm_action.comm_type == CommType.BEFORE: + if op_data.type == OperationDataType.OUTPUT: + comm_object = node + elif comm_action.key_for_kwarg is not None: + comm_object = node.kwargs[comm_action.key_for_kwarg] + else: + comm_object = node.args[comm_action.arg_index] + with mod_graph.inserting_before(node): + comm_spec_apply_node = mod_graph.create_node('call_function', + runtime_comm_spec_apply, + args=(comm_object, comm_actions_dict_node, + node_to_index_dict[node], op_data.name)) + # the origin node may be a positional argument or key word argument of user node + if comm_action.key_for_kwarg is not None: + # substitute the origin node with comm_spec_apply_node + new_kwargs = dict(node.kwargs) + new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node + node.kwargs = new_kwargs + else: + # substitute the origin node with comm_spec_apply_node + new_args = list(node.args) + new_args[comm_action.arg_index] = comm_spec_apply_node + node.args = tuple(new_args) + + elif comm_action.comm_type == CommType.AFTER: + with mod_graph.inserting_after(node): + comm_spec_apply_node = mod_graph.create_node('call_function', + runtime_comm_spec_apply, + args=(node, comm_actions_dict_node, + node_to_index_dict[node], op_data.name)) + user_list = list(node.users.keys()) + for user in user_list: + if user == comm_spec_apply_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with comm_spec_apply_node + new_args[new_args.index(node)] = comm_spec_apply_node + user.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with comm_spec_apply_node + new_kwargs[str(node)] = comm_spec_apply_node + user.kwargs = new_kwargs + return gm + + +def runtime_apply_pass(gm: torch.fx.GraphModule): + """ + The method manages all the passes acting on the distributed training runtime. + """ + gm = _shape_consistency_apply(gm) + gm = _comm_spec_apply(gm) + + return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py new file mode 100644 index 000000000..f9b890263 --- /dev/null +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -0,0 +1,471 @@ +import operator +from copy import deepcopy +from typing import Dict, List, Union + +import torch +from torch.fx import symbolic_trace +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationDataType, + ShardingStrategy, +) +from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.comm_spec import _all_reduce +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def size_processing(size: Union[int, torch.Size], + dim_partition_dict: Dict[int, List[int]], + device_mesh_info: Dict[int, int], + target_dim: int = None, + node_name: str = None): + """ + This method will be invoked during runtime to convert size node value depending on distributed information. + """ + if target_dim is not None: + assert isinstance(size, int) + if target_dim in dim_partition_dict: + total_shard_size = 1 + for shard_dim in dim_partition_dict[target_dim]: + total_shard_size *= device_mesh_info[shard_dim] + size = size * total_shard_size + + else: + size = list(size) + for dim, dim_size in enumerate(size): + if dim in dim_partition_dict: + total_shard_size = 1 + for shard_dim in dim_partition_dict[dim]: + total_shard_size *= device_mesh_info[shard_dim] + size[dim] = dim_size * total_shard_size + size = torch.Size(size) + + return size + + +def _solution_annotatation(gm: torch.fx.GraphModule, + solution: List[int], + strategies_constructor: StrategiesConstructor = None): + """ + This method is used to stick the solution strategy to the nodes and add the information + required in runtime into graph as placeholder nodes. + """ + mod_graph = gm.graph + # TODO: In future PR, strategies_constructor should be a required argument, + # instead of optional argument. This is because we don't need to consider nodes with + # no strategy in runtime preparation pass. + if strategies_constructor is not None: + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + no_strategy_nodes = strategies_constructor.no_strategy_nodes + else: + nodes = tuple(mod_graph.nodes) + no_strategy_nodes = [] + + # the dict to get origin sharding spec of node + origin_node_sharding_spec_dict = {} + for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): + strategies_vector = node.strategies_vector + # stick the solution strategy to the corresponding node + setattr(node, 'best_strategy', strategies_vector[strategy_index]) + setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) + origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( + str(node)) + + # attach the corresponding metainfo if node has the attribute `metainfo_vector` + if hasattr(node, 'metainfo_vector'): + setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index]) + + # the dict to get input sharding specs of user node + sharding_spec_convert_dict = {} + # the dict to record comm actions of nodes + comm_actions_dict = {} + for index, node in enumerate(nodes): + target_sharding_specs = [] + for user_node in node.strategies_vector.successor_nodes: + if user_node in no_strategy_nodes: + target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name)) + else: + target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name)) + target_sharding_specs.append(target_sharding_spec) + sharding_spec_convert_dict[index] = target_sharding_specs + setattr(node, 'target_sharding_specs', target_sharding_specs) + # the get_attr node strategy is kind of pending strategy, which means we will change it + # to the same strategy of the user node. + if node.op == 'get_attr': + assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.' + target_node = node.strategies_vector.successor_nodes[0] + node_name = str(node) + if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP: + node_name = str(target_node) + target_node = target_node.strategies_vector.successor_nodes[0] + user_strategy = target_node.best_strategy + op_data_in_user = user_strategy.get_op_data_by_name(node_name) + origin_pending_strategy = node.best_strategy + origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node)) + + new_communication_actions = {} + if op_data_in_user in user_strategy.communication_actions: + new_communication_action = user_strategy.communication_actions.pop(op_data_in_user) + new_communication_action.arg_index = 0 + new_communication_actions[origin_op_data] = new_communication_action + node.best_strategy.communication_actions = new_communication_actions + + comm_action_dict = {} + for op_data, comm_action in node.best_strategy.communication_actions.items(): + comm_action_dict[op_data.name] = comm_action + comm_actions_dict[index] = comm_action_dict + + # add above dicts into graph + for node in nodes: + if node.op != 'placeholder': + with mod_graph.inserting_before(node): + input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') + origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') + comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict') + break + return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict + + +def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): + """ + In the auto parallel system, tensors may get shard on different devices, so the size of tensors + need to be converted to the size of original tensor and managed by the users, such as torch.view, + torch.reshape, etc. These nodes have enough information like input sharding_spec and + output sharding_spec to decide how to convert the size value. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + node_pairs = {} + + for node in nodes: + + if node.op == 'call_method' and node.target == 'size': + # extract useful information from size node + # dim_partition_dict will instruct the size value on which + # dimension should be enlarged. + sharding_spec = node.args[0].sharding_spec + dim_partition_dict = sharding_spec.dim_partition_dict + + # there are two usages of torch.Tensor.size: + # tensor.size() + # tensor.size(dim) + # if a target_dim is assigned, then the output will be + # in type of int, instead of torch.Size + target_dim = None + if len(node.args) > 1: + target_dim = node.args[1] + if target_dim < 0: + target_dim += node.args[0]._meta_data.dim() + + # DeviceMesh information instructs the scaling of the size value + device_mesh_info = {} + for dim, dim_size in enumerate(device_mesh.mesh_shape): + device_mesh_info[dim] = dim_size + + with mod_graph.inserting_after(node): + size_processing_node = mod_graph.create_node('call_function', + size_processing, + args=(node, dim_partition_dict, device_mesh_info, + target_dim, node.name)) + # store original node and processing node pair in node_pairs dictioanry + # It will be used to replace the original node with processing node in slice object + node_pairs[node] = size_processing_node + size_processing_node._meta_data = node._meta_data + + user_list = list(node.users.keys()) + for user in user_list: + if user == size_processing_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with size_processing_node + new_args[new_args.index(node)] = size_processing_node + user.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with size_processing_node + new_kwargs[str(node)] = size_processing_node + user.kwargs = new_kwargs + + if node.op == 'call_function' and node.target == operator.getitem: + + getitem_index = node.args[1] + # slice object is quite special in torch.fx graph, + # On one side, we treat slice object same as type of int, + # so we do not create a node for slice object. On the other side, + # slice object could take fx.Node as its argument. And the user + # relationship cannot be tracked in fx graph. + # Therefore, I record the node_pairs in this pass, and use the it + # to replace the original node argument inside the slice object if + # it has been processed in above pass. + + # There are three main usages of operator.getitem: + # getitem(input, int) + # getitem(input, slice) + # getitem(input, Tuple[slice]) + # In this pass, we need process the last two cases because + # node arguments may potentially appear in these cases. + if isinstance(getitem_index, slice): + new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step + if getitem_index.start in node_pairs: + new_start = node_pairs[getitem_index.start] + elif getitem_index.stop in node_pairs: + new_stop = node_pairs[getitem_index.stop] + elif getitem_index.step in node_pairs: + new_step = node_pairs[getitem_index.step] + new_slice_item = slice(new_start, new_stop, new_step) + new_args = (node.args[0], new_slice_item) + node.args = new_args + + elif isinstance(getitem_index, (tuple, list)): + if not isinstance(getitem_index[0], slice): + continue + new_slice_items = [] + + for slice_item in getitem_index: + if slice_item is None: + new_slice_items.append(None) + continue + + new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step + + if slice_item.start in node_pairs: + new_start = node_pairs[slice_item.start] + elif slice_item.stop in node_pairs: + new_stop = node_pairs[slice_item.stop] + elif slice_item.step in node_pairs: + new_step = node_pairs[slice_item.step] + new_slice_item = slice(new_start, new_stop, new_step) + new_slice_items.append(new_slice_item) + + new_args = (node.args[0], tuple(new_slice_items)) + node.args = new_args + + return gm + + +def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): + """ + This pass will process node args to adapt the distributed tensor layout. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + for node in nodes: + # skip the placeholder node added in _solution_annotation pass + if not hasattr(node, 'sharding_spec'): + continue + + def _process_sharding_spec(sharding_spec): + if isinstance(sharding_spec, ShardingSpec): + dim_partition_dict = sharding_spec.dim_partition_dict + device_mesh = sharding_spec.device_mesh + return dim_partition_dict, device_mesh + if sharding_spec is None: + return None, None + assert isinstance(sharding_spec, + (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' + + device_mesh = sharding_spec[0].device_mesh + dim_partition_dict = [] + for element in sharding_spec: + dim_partition_dict.append(_process_sharding_spec(element)) + return dim_partition_dict, sharding_spec + + output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec) + new_args = [] + + if node.op == 'call_method': + method = getattr(node.args[0]._meta_data.__class__, node.target) + # process the node with (input, *shape) style args + if method in (torch.Tensor.view, torch.Tensor.reshape): + + for arg in node.args: + if isinstance(arg, Node): + if isinstance(arg._meta_data, (int, tuple, list)): + new_args.append(arg._meta_data) + else: + new_args.append(arg) + else: + assert isinstance( + arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.' + new_args.append(arg) + + for dim, shard_dims in output_dim_partition_dict.items(): + total_shard_size = 1 + for shard_dim in shard_dims: + total_shard_size *= device_mesh.shape[shard_dim] + # There are two ways to use torch.view: + # 1. torch.view(input, *shape) + # 2. torch.view(input, shape) + if isinstance(new_args[1], int): + # we will skip the dim with -1 value + if new_args[dim + 1] == -1: + continue + else: + new_args[dim + 1] //= total_shard_size + else: + new_args[1] = list(new_args[1]) + # we will skip the dim with -1 value + if new_args[1][dim] == -1: + continue + else: + new_args[1][dim] //= total_shard_size + node.args = tuple(new_args) + + elif node.op == 'call_function': + target = node.target + # process the node with (input, torch.Size) style args + if target in (torch.reshape,): + for arg in node.args: + if isinstance(arg, Node): + if isinstance(arg._meta_data, (tuple, list)): + new_args.append(list(arg._meta_data)) + else: + new_args.append(arg) + else: + assert isinstance( + arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.' + new_args.append(list(arg)) + + for dim, shard_dims in output_dim_partition_dict.items(): + # we will skip the dim with -1 value + if new_args[1][dim] == -1: + continue + total_shard_size = 1 + for shard_dim in shard_dims: + total_shard_size *= device_mesh.shape[shard_dim] + new_args[1][dim] //= total_shard_size + node.args = tuple(new_args) + + return gm + + +def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): + """ + Apply the sharding action to the module parameters and buffers following the + instructions of solver solution. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + # This stream is created for overlaping the communication and computation. + reduction_stream = torch.cuda.Stream() + for node in nodes: + if node.op == 'call_module': + target_module = node.graph.owning_module.get_submodule(node.target) + # TODO: we need to do more actions to take care of the shared parameters. + if hasattr(target_module, 'processed') and target_module.processed: + continue + setattr(target_module, 'processed', True) + for name, param in target_module.named_parameters(): + target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) + # apply the sharding spec of parameters + if target_sharding_spec.dim_partition_dict != {}: + origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) + setattr(param, 'sharding_spec', origin_sharding_spec) + # TODO: build a ColoParamter class to manager the distributed parameters + # we could use .data here, because all the operations just happen before the real training + # loop, so we don't need to track these operations in the autograd graph. + param.data = shape_consistency_manager.apply_for_autoparallel_runtime( + param.data, param.sharding_spec, target_sharding_spec).detach().clone() + + setattr(target_module, name, param) + comm_actions = node.best_strategy.communication_actions + for operation_data, comm_action in comm_actions.items(): + comm_spec_to_use = comm_action.comm_spec + # register hook to the parameters + if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK: + + def wrapper(param, comm_spec): + + def hook_fn(grad): + _all_reduce(grad, comm_spec, async_op=False) + + param.register_hook(hook_fn) + + wrapper(param, comm_spec_to_use) + + sharded_buffer_dict = {} + # apply the sharding spec of buffers + for name, buffer in target_module.named_buffers(): + origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {}) + setattr(buffer, 'sharding_spec', origin_sharding_spec) + target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) + buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec) + sharded_buffer_dict[name] = buffer_sharded + + for name, buffer_sharded in sharded_buffer_dict.items(): + setattr(target_module, name, buffer_sharded.detach().clone()) + + if node.op == 'get_attr': + root = node.graph.owning_module + atoms = node.target.split(".") + attr_len = len(atoms) + if attr_len == 1: + target_module = root + target = getattr(root, atoms[0]) + else: + target_module = root + for atom in atoms[:-1]: + target_module = getattr(target_module, atom) + target = getattr(target_module, atoms[-1]) + + target_sharding_spec = node.sharding_spec + if target_sharding_spec.dim_partition_dict != {}: + origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {}) + setattr(target, 'sharding_spec', origin_sharding_spec) + # TODO: build a ColoParamter class to manager the distributed parameters + # we could use .data here, because all the operations just happen before the real training + # loop, so we don't need to track these operations in the autograd graph. + target.data = shape_consistency_manager.apply_for_autoparallel_runtime( + target.data, target.sharding_spec, target_sharding_spec).detach().clone() + + assert hasattr(target_module, atoms[-1]) + setattr(target_module, atoms[-1], target) + + comm_actions = node.best_strategy.communication_actions + for operation_data, comm_action in comm_actions.items(): + comm_spec_to_use = comm_action.comm_spec + # register hook to the parameters + if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: + + def wrapper(param, comm_spec): + + def hook_fn(grad): + _all_reduce(grad, comm_spec, async_op=False) + + param.register_hook(hook_fn) + + wrapper(target, comm_spec_to_use) + return gm + + +def implicit_comm_action_apply(gm: torch.fx.GraphModule): + """ + replace the origin kernel into kernel with implicit communication inside. + """ + pass + + +def runtime_preparation_pass(gm: torch.fx.GraphModule, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor = None): + gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation( + gm, solution, strategies_constructor) + gm = _size_value_converting(gm, device_mesh) + gm = _node_args_converting(gm, device_mesh) + # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. + # gm = implicit_comm_action_apply(gm) + gm = _module_params_sharding(gm, device_mesh) + + return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py index 91c20d343..99c124934 100644 --- a/colossalai/auto_parallel/tensor_shard/constants.py +++ b/colossalai/auto_parallel/tensor_shard/constants.py @@ -1,6 +1,7 @@ -import torch import operator +import torch + __all__ = [ 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP', @@ -25,7 +26,14 @@ ELEMENTWISE_METHOD_OP = [ # TODO: contiguous maybe need some extra processes. torch.Tensor.contiguous ] -RESHAPE_FUNC_OP = [torch.flatten, torch.reshape] +RESHAPE_FUNC_OP = [ + torch.flatten, + torch.reshape, + torch.transpose, + torch.split, + torch.permute, + operator.getitem, +] RESHAPE_METHOD_OP = [ torch.Tensor.view, torch.Tensor.unsqueeze, @@ -35,7 +43,7 @@ RESHAPE_METHOD_OP = [ ] BCAST_FUNC_OP = [ torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, - operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh + operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow ] CONV_MODULE_OP = [ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py index a081ce69c..bd47f2adf 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py @@ -1,6 +1,6 @@ -from .options import SolverOptions -from .strategies_constructor import StrategiesConstructor -from .sharding_strategy import ShardingStrategy, StrategiesVector from .cost_graph import CostGraph +from .graph_analysis import GraphAnalyser +from .options import SolverOptions +from .sharding_strategy import ShardingStrategy, StrategiesVector from .solver import Solver -from .graph_analysis import GraphAnalyser \ No newline at end of file +from .strategies_constructor import StrategiesConstructor diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py index a72d97554..d6af7ad57 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py @@ -5,10 +5,11 @@ from functools import reduce from typing import Dict, List, Optional, Union import torch +from torch.fx.node import Node + from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec -from torch.fx.node import Node from .constants import INFINITY_COST @@ -17,7 +18,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: """ Generate the sharding spec of the tensor based on the given dim_partition_dict. - + Args: input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node. @@ -58,7 +59,7 @@ def generate_resharding_costs(nodes: List[Node], nodes (List[Node]): a list of nodes sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. - dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. + dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. ''' # The resharding_cost of weight is counted due to sharing weight cases. resharding_costs = {} diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py index c41ca6370..d8952040d 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py @@ -3,9 +3,9 @@ import warnings from functools import reduce import torch -from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ - ignore_sharding_exception -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + +from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler @@ -71,19 +71,19 @@ class ConvHandler(OperatorHandler): Argument: sharding_size_forward(int): The forward activation will be divided into sharding_size_forward number partions. - sharding_size_backward_activation(int): The backward activation will + sharding_size_backward_activation(int): The backward activation will be divided into sharding_size_backward_activation number partions. sharding_size_weight(int): The backward weight will be divided into sharding_size_weight number partions. Return: - memory_cost(Tuple[float]): Memory cost per device with this + memory_cost(Tuple[float]): Memory cost per device with this specific strategy, the first element of this tuple is forward memory cost, and the second element of this tuple is backward memory cost. - memory_cost_forward(float): Memory cost of forward activation per + memory_cost_forward(float): Memory cost of forward activation per device with this specific strategy. - memory_cost_backward_activation(float): Memory cost of backward activation + memory_cost_backward_activation(float): Memory cost of backward activation per device with this specific strategy. ''' # compute the memory cost of this strategy @@ -541,14 +541,14 @@ class ConvHandler(OperatorHandler): # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]] strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input) setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) - + strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ]) conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2], device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager) conv_handler.register_strategy_into_strategies_vector() for strategy in conv_handler.strategies_vector: print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}') - + Output: S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]} S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]} diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py index 4feeacd98..1f2281cc4 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py @@ -6,9 +6,9 @@ from typing import List import torch import torch.nn as nn import torch.nn.functional as F -from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ - ignore_sharding_exception -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + +from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP from .operator_handler import OperatorHandler @@ -82,13 +82,13 @@ class MatVecStrategyGenerator(StrategyGenerator): class MatMulStrategyGenerator(StrategyGenerator): """ - MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is + MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm. A matmul can be formulated as [n, p] x [p, q] = [n, q] Args: - is_linear (bool): whether this generator is used for nn.Linear and F.linear. + is_linear (bool): whether this generator is used for nn.Linear and F.linear. This will incur extra transformation of the dim partitioning as the weight is transposed. """ @@ -255,7 +255,7 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator): """ Generate sharding strategies for the batched matrix multiplication. - A batched matrix multiplication can be viewed as + A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] """ @@ -431,7 +431,7 @@ class DotHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -451,7 +451,7 @@ class DotHandler(OperatorHandler): # create and register strategy sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=toatl_memory_cost, @@ -473,7 +473,7 @@ class DotHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -491,7 +491,7 @@ class DotHandler(OperatorHandler): communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0) communication_cost = communication_cost_activation_forward + communication_cost_grad_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=toatl_memory_cost, @@ -510,7 +510,7 @@ class DotHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -529,7 +529,7 @@ class DotHandler(OperatorHandler): communication_cost = communication_cost_activation_backward + communication_cost_activation_forward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=toatl_memory_cost, @@ -548,7 +548,7 @@ class DotHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -564,7 +564,7 @@ class DotHandler(OperatorHandler): # compute the communication cost of this strategy communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim) sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=toatl_memory_cost, @@ -583,7 +583,7 @@ class DotHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -600,7 +600,7 @@ class DotHandler(OperatorHandler): communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim) communication_cost = communication_cost_activation_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=toatl_memory_cost, @@ -619,7 +619,7 @@ class DotHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -636,7 +636,7 @@ class DotHandler(OperatorHandler): communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0) communication_cost = communication_cost_weight_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=toatl_memory_cost, @@ -655,7 +655,7 @@ class DotHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -673,7 +673,7 @@ class DotHandler(OperatorHandler): activation_memory_cost, 0) communication_cost = communication_cost_forward_activation sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=toatl_memory_cost, @@ -692,7 +692,7 @@ class DotHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -709,7 +709,7 @@ class DotHandler(OperatorHandler): input_grad_memory_cost, 0) communication_cost = communication_cost_activation_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=toatl_memory_cost, diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py index c75fdbbb6..8062d0f4b 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py @@ -2,10 +2,14 @@ import operator from functools import reduce import torch -from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding, - enumerate_all_possible_2d_sharding, - generate_sharding_size, ignore_sharding_exception) -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + +from colossalai.auto_parallel.tensor_shard.deprecated._utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + generate_sharding_size, + ignore_sharding_exception, +) +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler @@ -63,19 +67,19 @@ class LayerNormHandler(OperatorHandler): Argument: sharding_size_forward(int): The forward activation will be divided into sharding_size_forward number partions. - sharding_size_backward_activation(int): The backward activation will + sharding_size_backward_activation(int): The backward activation will be divided into sharding_size_backward_activation number partions. sharding_size_weight(int): The backward weight will be divided into sharding_size_weight number partions. Return: - memory_cost(Tuple[float]): Memory cost per device with this + memory_cost(Tuple[float]): Memory cost per device with this specific strategy, the first element of this tuple is forward memory cost, and the second element of this tuple is backward memory cost. - memory_cost_forward(float): Memory cost of forward activation per + memory_cost_forward(float): Memory cost of forward activation per device with this specific strategy. - memory_cost_backward_activation(float): Memory cost of backward activation + memory_cost_backward_activation(float): Memory cost of backward activation per device with this specific strategy. ''' # compute the memory cost of this strategy @@ -216,7 +220,7 @@ class LayerNormHandler(OperatorHandler): norm_handler.register_strategy() for strategy in norm_handler.strategies_vector: print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') - + Output: RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0 RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0 diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py index 79f72d8d5..b120cc16b 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py @@ -1,14 +1,16 @@ +from abc import ABC, abstractmethod +from typing import Dict, List from webbrowser import Opera + import torch import torch.nn as nn -from abc import ABC, abstractmethod from torch.fx.node import Node -from typing import Dict, List + +from colossalai.auto_parallel.tensor_shard.deprecated.constants import * from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec -from .._utils import generate_resharding_costs, generate_sharding_spec -from colossalai.auto_parallel.tensor_shard.deprecated.constants import * +from .._utils import generate_resharding_costs, generate_sharding_spec from ..sharding_strategy import StrategiesVector __all__ = ['OperatorHandler'] @@ -60,7 +62,7 @@ class OperatorHandler(ABC): @abstractmethod def register_strategy(self) -> StrategiesVector: """ - Register + Register """ pass diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py index 2d3967025..d4ccc8a9c 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py @@ -4,9 +4,9 @@ import warnings from copy import deepcopy import torch -from colossalai.auto_parallel.tensor_shard.deprecated._utils import \ - ignore_sharding_exception -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) + +from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py index 528d37977..7bebde9d6 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py @@ -1,18 +1,21 @@ +import builtins +import math +import operator +from copy import deepcopy +from typing import Dict, List + +import torch from torch.fx import Graph, Node -from colossalai.tensor.sharding_spec import ShardingSpec + from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +from ._utils import generate_resharding_costs, generate_sharding_spec +from .constants import * +from .op_handler import * from .options import SolverOptions from .sharding_strategy import ShardingStrategy, StrategiesVector -from .op_handler import * -from .constants import * -from copy import deepcopy -import math -import torch -import operator -from typing import Dict, List -from ._utils import generate_sharding_spec, generate_resharding_costs -import builtins __all__ = ['StrategiesConstructor'] diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py new file mode 100644 index 000000000..0dce2564c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -0,0 +1,275 @@ +from typing import Dict, List, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.fx import GraphModule +from torch.fx.graph import Graph + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec + + +class ModuleWrapper(nn.Module): + ''' + This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict + into the forward function. + ''' + + def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]], + origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]): + ''' + Args: + module: the original module + sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node. + origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor. + comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor. + ''' + super(ModuleWrapper, self).__init__() + self.module = module + self.sharding_spec_dict = sharding_spec_dict + self.origin_spec_dict = origin_spec_dict + self.comm_actions_dict = comm_actions_dict + + def forward(self, *args, **kwargs): + return self.module(*args, + sharding_spec_convert_dict=self.sharding_spec_dict, + origin_node_sharding_spec_dict=self.origin_spec_dict, + comm_actions_dict=self.comm_actions_dict, + **kwargs) + + +def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable): + ''' + This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func. + ''' + # TODO: implement this function + pass + + +def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]): + ''' + This method is used to search the best logical mesh shape for the given world size + based on the alpha_beta_dict. + + For example: + if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1). + ''' + # TODO: implement this function + return (world_size, 1) + + +def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]): + ''' + This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape + from the alpha_beta_dict. These two values will be used to estimate the communication cost. + ''' + # TODO: implement this function + pass + + +def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh): + ''' + This method is used to build the strategy_constructor for the given graph. + After this method, each node in the graph will have a strategies_vector which + is constructed by the related node handler. + ''' + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + return strategies_constructor + + +def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0): + ''' + This method is used to solve the best solution for the given graph. + The solution is a list of integers, each integer represents the best strategy index of the corresponding node. + ''' + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + cost_graph = CostGraph(strategy_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + + return solution + + +def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor): + ''' + This method is used to transform the original graph to the sharded graph. + The model parameters will be sharded according to the solution and the grad hooks + will be added to the sharded graph using the runtime_preparation_pass. + The communication node will be added into the graph using the runtime_apply_pass. + ''' + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh, strategies_constructor) + gm = runtime_apply_pass(gm) + gm.recompile() + sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict) + + return gm, sharding_spec_dicts + + +def initialize_device_mesh(world_size: int = -1, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None): + ''' + This method is used to initialize the device mesh. + + Args: + world_size(optional): the size of device mesh. If the world_size is -1, + the world size will be set to the number of GPUs in the current machine. + alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values + for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be + generated by profile_alpha_beta function. + logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical + mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be + generated by search_best_logical_mesh_shape function. + ''' + # if world_size is not set, use the world size from torch.distributed + if world_size == -1: + world_size = dist.get_world_size() + device1d = [i for i in range(world_size)] + + if alpha_beta_dict is None: + # if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device + alpha_beta_dict = profile_alpha_beta(device1d) + + if logical_mesh_shape is None: + # search for the best logical mesh shape + logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict) + + # extract alpha and beta values for the chosen logical mesh shape + mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape) + physical_mesh = torch.tensor(device1d) + device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, + mesh_shape=logical_mesh_shape, + mesh_alpha=mesh_alpha, + mesh_beta=mesh_beta, + init_process_group=True) + return device_mesh + + +def initialize_model(model: nn.Module, + meta_args: Dict[str, torch.Tensor], + device_mesh: DeviceMesh, + memory_budget: float = -1.0, + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solution_path: str = None, + return_solution: bool = False): + ''' + This method is used to initialize the sharded model which could be used as normal pytorch model. + + Args: + model: the model to be sharded. + meta_args: the meta_args is used to specify the input shapes of the model. + device_mesh: the device mesh to execute the model. + memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, + the memory budget will be infinity. + save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved + to the solution_path. + load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded + from the solution_path. + solution_path(optional): the path to save or load the solution. + return_solution(optional): if the return_solution is True, the solution will be returned. The returned + solution will be used to debug or help to analyze the sharding result. Therefore, we will not just + return a series of integers, but return the best strategies. + ''' + tracer = ColoTracer() + + graph = tracer.trace(root=model, meta_args=meta_args) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + strategies_constructor = build_strategy_constructor(graph, device_mesh) + if load_solver_solution: + solution = torch.load(solution_path) + else: + solution = solve_solution(gm, strategies_constructor, memory_budget) + if save_solver_solution: + torch.save(solution, solution_path) + + gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor) + model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) + + if return_solution: + solution_to_return = [] + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + for index, node in enumerate(nodes): + solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}') + return model_to_return, solution_to_return + else: + return model_to_return + + +def autoparallelize(model: nn.Module, + meta_args: Dict[str, torch.Tensor] = None, + data_loader: torch.utils.data.DataLoader = None, + data_process_func: callable = None, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None, + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solver_solution_path: str = None, + return_solution: bool = False, + memory_budget: float = -1.0): + ''' + This method is used to initialize the device mesh, extract the meta_args, and + use them to create a sharded model. + + Args: + model: the model to be sharded. + meta_args(optional): the meta_args is used to specify the input shapes of the model. + If the meta_args is None, the meta_args will be extracted from the data_loader. + data_loader(optional): the data_loader to be used in normal training loop. + data_process_func(optional): the data_process_func is used to process the data from the data_loader. + alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values + for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be + generated by profile_alpha_beta function. + logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical + mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be + generated by search_best_logical_mesh_shape function. + save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved + to the solution_path. + load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded + from the solution_path. + solver_solution_path(optional): the path to save or load the solution. + return_solution(optional): if the return_solution is True, the solution will be returned. + memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, + the memory budget will be infinity. + ''' + device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape) + if meta_args is None: + meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func) + + rst_to_unpack = initialize_model(model, + meta_args, + device_mesh, + save_solver_solution=save_solver_solution, + load_solver_solution=load_solver_solution, + solver_solution_path=solver_solution_path, + return_solution=return_solution, + memory_budget=memory_budget) + + if return_solution: + model, solution = rst_to_unpack + return model, solution + else: + model = rst_to_unpack + return model diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index b9227e2ec..a5e3f649a 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -1,19 +1,31 @@ +from .addmm_handler import ADDMMFunctionHandler from .batch_norm_handler import BatchNormModuleHandler +from .binary_elementwise_handler import BinaryElementwiseHandler from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler from .conv_handler import ConvFunctionHandler, ConvModuleHandler +from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler +from .experimental import PermuteHandler, ViewHandler +from .getattr_handler import GetattrHandler +from .getitem_handler import GetItemHandler from .layer_norm_handler import LayerNormModuleHandler from .linear_handler import LinearFunctionHandler, LinearModuleHandler +from .matmul_handler import MatMulHandler from .normal_pooling_handler import NormPoolingHandler -from .output_handler import OuputHandler -from .placeholder_handler import PlacehodlerHandler +from .output_handler import OutputHandler +from .placeholder_handler import PlaceholderHandler from .registry import operator_registry from .reshape_handler import ReshapeHandler +from .softmax_handler import SoftmaxHandler +from .sum_handler import SumHandler +from .tensor_constructor_handler import TensorConstructorHandler from .unary_elementwise_handler import UnaryElementwiseHandler from .where_handler import WhereHandler __all__ = [ 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', - 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', - 'NormPoolingHandler', 'operator_registry' + 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', + 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', + 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', + 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py new file mode 100644 index 000000000..da0d199c5 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py @@ -0,0 +1,91 @@ +from typing import Dict, List, Union + +import torch + +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager + +from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator + +__all__ = ['ADDMMFunctionHandler'] + + +@operator_registry.register(torch.addmm) +@operator_registry.register(torch.Tensor.addmm) +class ADDMMFunctionHandler(NodeHandler): + """ + This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch. + Such operations including `torch.bmm` and `torch.Tensor.bmm` require the tensor to be 3D, thus, there is + no logical-physical shape conversion in this handler. + """ + + def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType: + if isinstance(tensor, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + return data_type + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + + # input operand + input_data = self.node.args[1]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[1]), + type=self._infer_op_data_type(input_data), + data=input_data) + + # other operand + other_data = self.node.args[2]._meta_data + physical_other_operand = OperationData(name=str(self.node.args[2]), + type=self._infer_op_data_type(other_data), + data=other_data) + # bias physical shape + bias_logical_shape = self.node._meta_data.shape + bias_data = self.node.args[0]._meta_data + physical_bias_operand = OperationData(name=str(self.node.args[0]), + type=self._infer_op_data_type(bias_data), + data=bias_data, + logical_shape=bias_logical_shape) + + # output + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = { + "input": physical_input_operand, + "other": physical_other_operand, + "output": physical_output, + 'bias': physical_bias_operand + } + + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm')) + return generators + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # convert bias from its logical sharding spec to its physical sharding spec + op_data_mapping = self.get_operation_data_mapping() + + bias_op_data = op_data_mapping['bias'] + bias_physical_shape = bias_op_data.data.shape + bias_logical_shape = bias_op_data.logical_shape + bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) + bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + bias_sharding_spec, bias_logical_shape, bias_physical_shape) + strategy.sharding_specs[bias_op_data] = bias_sharding_spec + + if len(removed_dims) > 0: + comm_action = comm_actions_for_oprands(node=self.node, + removed_dims=removed_dims, + op_data=bias_op_data, + sharding_spec=bias_sharding_spec) + strategy.communication_actions[bias_op_data] = comm_action + + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py index 6bdd15d16..57b623b01 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py @@ -2,8 +2,10 @@ from typing import Dict, List import torch -from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import ModuleHandler +from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector +from .node_handler import MetaInfoModuleHandler, ModuleHandler from .registry import operator_registry from .strategy import BatchNormStrategyGenerator, StrategyGenerator @@ -13,7 +15,7 @@ __all__ = ['BatchNormModuleHandler'] @operator_registry.register(torch.nn.BatchNorm1d) @operator_registry.register(torch.nn.BatchNorm2d) @operator_registry.register(torch.nn.BatchNorm3d) -class BatchNormModuleHandler(ModuleHandler): +class BatchNormModuleHandler(MetaInfoModuleHandler): """ A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module. """ diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py new file mode 100644 index 000000000..f510f7477 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -0,0 +1,96 @@ +from typing import Dict, List, Union + +import torch +from torch.fx.node import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager + +from ..constants import BCAST_FUNC_OP +from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape +from .node_handler import MetaInfoNodeHandler, NodeHandler +from .registry import operator_registry +from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator + +__all__ = ['BinaryElementwiseHandler'] + + +@operator_registry.register(BCAST_FUNC_OP) +class BinaryElementwiseHandler(MetaInfoNodeHandler): + """ + An BinaryBcastOpHandler is a node handler which deals with operations which have two + operands and broadcasting occurs such as torch.add. + """ + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + bcast_shape = self.node._meta_data.shape + + def _get_op_data_type(tensor): + if isinstance(tensor, torch.nn.parameter.Parameter): + return OperationDataType.PARAM + else: + return OperationDataType.ARG + + def _get_arg_value(idx): + if isinstance(self.node.args[idx], Node): + meta_data = self.node.args[idx]._meta_data + else: + # this is in fact a real data like int 1 + # but we can deem it as meta data + # as it won't affect the strategy generation + assert isinstance(self.node.args[idx], (int, float)) + meta_data = torch.Tensor([self.node.args[idx]]).to('meta') + return meta_data + + input_meta_data = _get_arg_value(0) + other_meta_data = _get_arg_value(1) + output_meta_data = self.node._meta_data + + input_op_data = OperationData(name=str(self.node.args[0]), + type=_get_op_data_type(input_meta_data), + data=input_meta_data, + logical_shape=bcast_shape) + other_op_data = OperationData(name=str(self.node.args[1]), + type=_get_op_data_type(other_meta_data), + data=other_meta_data, + logical_shape=bcast_shape) + output_op_data = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=bcast_shape) + + mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BinaryElementwiseStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # convert bias from its logical sharding spec to its physical sharding spec + op_data_mapping = self.get_operation_data_mapping() + + for op_name, op_data in op_data_mapping.items(): + if not isinstance(op_data.data, torch.Tensor): + # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2) + strategy.sharding_specs.pop(op_data) + else: + # convert the logical sharding spec to physical sharding spec if broadcast + # e.g. torch.rand(4, 4) + torch.rand(4) + physical_shape = op_data.data.shape + logical_shape = op_data.logical_shape + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + sharding_spec, logical_shape, physical_shape) + + strategy.sharding_specs[op_data] = sharding_spec + if len(removed_dims) > 0: + comm_action = comm_actions_for_oprands(node=self.node, + removed_dims=removed_dims, + op_data=op_data, + sharding_spec=sharding_spec) + strategy.communication_actions[op_data] = comm_action + + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py index 09016d507..9e1d958e1 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py @@ -2,8 +2,10 @@ from typing import Dict, List, Union import torch -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy -from ..utils import recover_sharding_spec_for_broadcast_shape +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager + +from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator @@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler): bias_physical_shape = bias_op_data.data.shape bias_logical_shape = bias_op_data.logical_shape bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) - bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape, - bias_physical_shape) + bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + bias_sharding_spec, bias_logical_shape, bias_physical_shape) strategy.sharding_specs[bias_op_data] = bias_sharding_spec + + if len(removed_dims) > 0: + comm_action = comm_actions_for_oprands(node=self.node, + removed_dims=removed_dims, + op_data=bias_op_data, + sharding_spec=bias_sharding_spec) + strategy.communication_actions[bias_op_data] = comm_action + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py index 0c00160ef..272b1c856 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py @@ -3,9 +3,9 @@ from typing import Dict, List import torch import torch.nn.functional as F -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector from ..utils import transpose_partition_dim -from .node_handler import ModuleHandler, NodeHandler +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler from .registry import operator_registry from .strategy import ConvStrategyGenerator, StrategyGenerator @@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler'] @operator_registry.register(torch.nn.Conv1d) @operator_registry.register(torch.nn.Conv2d) @operator_registry.register(torch.nn.Conv3d) -class ConvModuleHandler(ModuleHandler): +class ConvModuleHandler(MetaInfoModuleHandler): """ A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module. """ @@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler): @operator_registry.register(F.conv1d) @operator_registry.register(F.conv2d) @operator_registry.register(F.conv3d) -class ConvFunctionHandler(NodeHandler): +class ConvFunctionHandler(MetaInfoNodeHandler): """ A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions. """ diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py new file mode 100644 index 000000000..e154105b6 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py @@ -0,0 +1,230 @@ +from typing import Dict, List, Union + +import torch +import torch.nn.functional as F + +from colossalai.auto_parallel.tensor_shard.utils import update_partition_dim +from colossalai.logging import get_dist_logger +from colossalai.tensor.sharding_spec import ShardingNotDivisibleError + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from .node_handler import ModuleHandler, NodeHandler +from .registry import operator_registry +from .strategy import EmbeddingStrategyGenerator, StrategyGenerator + +__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler'] + + +def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str, + output_name: str) -> List[ShardingStrategy]: + """ + This function converts the logical sharding spec to the physical sharding spec for both the input and output + of the embedding operation. + + Args: + strategy (ShardingStrategy): the logical strategy generated by the strategy generator. + input_name (str): the name of the OperationData object for the input. + output_name (str): the name of the OperationData object for the output. + """ + # the result will be a list of strategies + sharding_strategies = [] + + # get operation data + input_op_data = strategy.get_op_data_by_name(input_name) + output_op_data = strategy.get_op_data_by_name(output_name) + input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name) + + # recover the last logical dimension to physical dimension + last_logical_output_dims = len(output_op_data.logical_shape) - 1 + last_physical_output_dims = output_op_data.data.dim() - 1 + + # get logger for debug message + logger = get_dist_logger() + + # For the input of the embedding operation, it can be multi-dimensional. The sharding spec is only generated for + # logical 1D non-matrix dimension, the logical non-matrix dimension can belong to the 0th to Nth dimension of the + # physical input shape. Thus, we enumerate to get all possible cases. + if input_sharding_spec.dim_partition_dict: + # if bool(input_sharding_spec.dim_partition_dict), it means that the + # the generated sharding strategy does shard the non-matrix dimension, + # in this case, we need to do enumeration + num_input_dims = input_op_data.data.dim() + for i in range(num_input_dims): + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + try: + # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping={0: i}, + physical_shape=input_op_data.data.shape, + inplace=True) + + if last_logical_output_dims in output_sharding_spec.dim_partition_dict: + dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims} + else: + dim_mapping = {0: i} + + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + + strategy_copy.name = f'{strategy.name}_{i}' + sharding_strategies.append(strategy_copy) + + except ShardingNotDivisibleError as e: + logger.debug( + f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + ) + else: + # the generated sharding strategy does not shard the non-matrix dimension, + # in this case, we don't need to do enumeration + # but instead, we still need to convert the logical shape to physical shape + strategy_copy = strategy.clone() + input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) + + # after updating, the logical shape will be replaced by the physical shape + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping={}, + physical_shape=input_op_data.data.shape, + inplace=True) + + if last_logical_output_dims in output_sharding_spec.dim_partition_dict: + dim_mapping = {last_logical_output_dims: last_physical_output_dims} + else: + dim_mapping = {} + + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True) + sharding_strategies.append(strategy_copy) + + return sharding_strategies + + +@operator_registry.register(torch.nn.Embedding) +class EmbeddingModuleHandler(ModuleHandler): + """ + A EmbeddingModuleHandler which deals with the sharding strategies for nn.Embedding module. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # In nn.Embedding operation, all the dimensions of input will be treated as the batch dimension, + # and then the sharding spec will be generated based on the logical 1D tensor. + # After that, the logical sharding info will be enumerated among all the physical dimensions. + # Finally, the input will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape) + + physical_other_operand = OperationData(name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters['weight']) + + # Same as input, in nn.Embedding operation, all the dimensions of output will be treated as + # (batch dimension, embedding dimension), and then the sharding spec will be generated based + # on the logical 2D tensor. + # After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions. + # Finally, the output will be transformed back to its original shape in self.post_process + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + return mapping + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + """ + Convert the sharding spec from the logical shape to the physical shape. + """ + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, + input_name=str( + self.node.args[0]), + output_name=str(self.node)) + return strategies + + +@operator_registry.register(F.embedding) +class EmbeddingFunctionHandler(NodeHandler): + """ + A EmbeddingFunctionHandler which deals with the sharding strategies for F.embedding. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # In F.embedding operation, all the dimensions of input will be treated as the batch dimension, + # and then the sharding spec will be generated based on the logical 1D tensor. + # After that, the logical sharding info will be enumerated among all the physical dimensions. + # Finally, the input will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1).shape + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape) + + # check if the other operand is a parameter + if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data) + + # Same as input, in F.embedding operation, all the dimensions of output will be treated as + # (batch dimension, embedding dimension), and then the sharding spec will be generated based + # on the logical 2D tensor. + # After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions. + # Finally, the output will be transformed back to its original shape in self.post_process + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.node._meta_data, + logical_shape=output_logical_shape, + ) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + return mapping + + def post_process(self, strategy: ShardingStrategy): + """ + Convert the sharding spec from the logical shape to the physical shape. + """ + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, + input_name=str( + self.node.args[0]), + output_name=str(self.node)) + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py new file mode 100644 index 000000000..15f66104b --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py @@ -0,0 +1,10 @@ +from .permute_handler import PermuteHandler +from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator +from .split_handler import SplitHandler +from .transpose_handler import TransposeHandler +from .view_handler import ViewHandler + +__all__ = [ + 'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator', + 'SplitHandler', 'SplitGenerator' +] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py new file mode 100644 index 000000000..6d625e153 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py @@ -0,0 +1,76 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .reshape_generator import PermuteGenerator + +__all__ = ['PermuteHandler'] + + +@operator_registry.register(torch.Tensor.permute) +@operator_registry.register(torch.permute) +class PermuteHandler(NodeHandler): + """ + A PermuteHandler which deals with the sharding strategies for torch.permute or torch.transpose. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + permute_dims = [] + if self.node.op == 'call_method': + # torch.Tensor.permute (input, *dims) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, int): + permute_dims.append(arg._meta_data) + else: + assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.' + permute_dims.append(arg) + else: + # torch.permute (input, dims) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, (tuple, list)): + permute_dims.extend(arg._meta_data) + else: + assert isinstance( + arg, + (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].' + permute_dims.extend(arg) + + num_dims = self.node._meta_data.dim() + for i in range(num_dims): + # recover negative value to positive + if permute_dims[i] < 0: + permute_dims[i] += num_dims + + physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims)) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "permute_dims": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py new file mode 100644 index 000000000..b7248d011 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py @@ -0,0 +1,299 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator'] + + +class ReshapeGenerator(FollowingStrategyGenerator): + """ + ReshapeGenerator is the base class for all the reshape operation. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + return super().collate_strategies() + + +class ViewGenerator(ReshapeGenerator): + """ + ViewGenerator deals with the sharding strategies of view op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + origin_shape = self.op_data['input'].data.shape + tgt_shape = self.op_data['tgt_shape'].data + + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict) + + if keep_sharding_status: + dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input, + reshape_mapping_dict) + else: + dim_partition_dict_for_output = {} + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + if keep_sharding_status: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + else: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}' + + # add comm action for converting input to fully replicated + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition_dict_for_input.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + # the total mesh dim list only has one element, so the shard dim has only one element as well. + shard_dim = list(dim_partition_dict_for_input.keys())[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.BEFORE, + arg_index=0) + # it will gather the input through gather_dim during forward phase. + input_comm_action.comm_spec.gather_dim = shard_dim + # it will split the input activation grad through shard_dim during backward phase. + input_comm_action.comm_spec.shard_dim = shard_dim + + elif len(total_mesh_dim_list) >= 2: + source_spec = sharding_spec_mapping["input"] + target_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=source_spec.entire_shape, + dim_partition_dict={}) + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class PermuteGenerator(ReshapeGenerator): + """ + PermuteGenerator deals with the sharding strategies of permute op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + permute_dims = self.op_data['permute_dims'].data + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + for dim_index, permute_dim in enumerate(permute_dims): + if permute_dim in dim_partition_dict_for_input: + dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim] + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class TransposeGenerator(ReshapeGenerator): + """ + TransposeGenerator deals with the sharding strategies of permute op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + + transpose_dims = self.op_data['transpose_dims'].data + dim_0 = transpose_dims[0] + dim_1 = transpose_dims[1] + for dim, sharded_dims in dim_partition_dict_for_input.items(): + if dim == dim_0: + dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0] + elif dim == dim_1: + dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1] + else: + dim_partition_dict_for_output[dim] = sharded_dims + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class SplitGenerator(ReshapeGenerator): + """ + SplitGenerator deals with the sharding strategies of split op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + recover_dims = None + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + split_size, split_dim = self.op_data['split_info'].data + + if split_dim in dim_partition_dict_for_input: + recover_dims = dim_partition_dict_for_input.pop(split_dim) + + dim_partition_dict_for_output = [ + copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data)) + ] + assert len(dim_partition_dict_for_output) >= 2 + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}' + + # add comm action if the input need to be recovered to replica in the split dimension. + if recover_dims: + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(recover_dims) == 1: + recover_dims = recover_dims[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=recover_dims, + comm_type=CommType.BEFORE, + arg_index=0) + # it will gather the input through gather_dim during forward phase. + input_comm_action.comm_spec.gather_dim = split_dim + # it will split the input activation grad through split_dim during backward phase. + input_comm_action.comm_spec.shard_dim = split_dim + + elif len(recover_dims) >= 2: + # original sharding spec + source_spec = input_sharding_spec + # target sharding spec + target_spec = sharding_spec_mapping["input"] + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py new file mode 100644 index 000000000..38c5eed7d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py @@ -0,0 +1,63 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .reshape_generator import SplitGenerator + +__all__ = ['SplitHandler'] + + +@operator_registry.register(torch.Tensor.split) +@operator_registry.register(torch.split) +class SplitHandler(NodeHandler): + """ + A SplitHandler which deals with the sharding strategies for torch.permute or torch.split. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + split_size = self.node.args[1] + if len(self.node.args) == 3: + # (input, split_size, split_dim) + split_dim = self.node.args[2] + else: + if self.node.kwargs: + split_dim = self.node.kwargs['dim'] + else: + split_dim = 0 + + num_dims = self.node.args[0]._meta_data.dim() + # recover negative value to positive + if split_dim < 0: + split_dim += num_dims + + split_info = (split_size, split_dim) + physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "split_info": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py new file mode 100644 index 000000000..3c7336a93 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py @@ -0,0 +1,65 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .reshape_generator import TransposeGenerator + +__all__ = ['TransposeHandler'] + + +@operator_registry.register(torch.Tensor.transpose) +@operator_registry.register(torch.transpose) +class TransposeHandler(NodeHandler): + """ + A TransposeHandler which deals with the sharding strategies for torch.permute or torch.transpose. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + transpose_dims = [] + # torch.transpose (input, dim0, dim1) + for arg in self.node.args: + if isinstance(arg, torch.fx.Node): + if isinstance(arg._meta_data, int): + transpose_dims.append(arg._meta_data) + else: + transpose_dims.append(arg) + + num_dims = self.node._meta_data.dim() + for i in range(2): + # recover negative value to positive + if transpose_dims[i] < 0: + transpose_dims[i] += num_dims + + physical_shape_operand = OperationData(name='transpose_dims', + type=OperationDataType.ARG, + data=list(transpose_dims)) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "transpose_dims": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py new file mode 100644 index 000000000..6be634593 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py @@ -0,0 +1,53 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .reshape_generator import ViewGenerator + +__all__ = ['ViewHandler'] + + +@operator_registry.register(torch.Tensor.reshape) +@operator_registry.register(torch.reshape) +@operator_registry.register(torch.Tensor.view) +class ViewHandler(NodeHandler): + """ + A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + target_shape = self.node._meta_data.shape + physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "tgt_shape": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py new file mode 100644 index 000000000..53addb873 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py @@ -0,0 +1,34 @@ +from typing import Dict, List + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .strategy import GetattrGenerator, StrategyGenerator + +__all__ = ['GetattrHandler'] + + +class GetattrHandler(NodeHandler): + """ + A GetattrHandler which deals with the sharding strategies for Getattr Node. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(GetattrGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + + # There are only two possible types for get_attr node: + # 1. torch.Tensor(torch.nn.Parameters or torch.nn.Buffers) + # 2. torch.nn.Module + # temporarily, we just support first case in Tracer, so we don't have to worry about + # issue related to the node._meta_data type. + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py index 25baa7766..3466e9dd9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py @@ -6,7 +6,7 @@ import torch from ..sharding_strategy import OperationData, OperationDataType from .node_handler import NodeHandler from .registry import operator_registry -from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator) +from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator __all__ = ['GetItemHandler'] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index 62210ebe9..37ff3c3ab 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -3,12 +3,16 @@ from typing import Dict, List, Union import torch import torch.nn.functional as F -from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim +from colossalai.auto_parallel.tensor_shard.utils import ( + check_sharding_spec_validity, + transpose_partition_dim, + update_partition_dim, +) from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingNotDivisibleError -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy -from .node_handler import ModuleHandler, NodeHandler +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler from .registry import operator_registry from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator @@ -28,9 +32,11 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr # switch the dimensions of the transposed weight sharding_spec = strategy.get_sharding_spec_by_name(weight_name) op_data = strategy.get_op_data_by_name(weight_name) - assert op_data.logical_shape != op_data.data.shape, \ - "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same" - transpose_partition_dim(sharding_spec, 0, -1) + assert op_data.logical_shape[0] == op_data.data.shape[1] and \ + op_data.logical_shape[1] == op_data.data.shape[0], \ + "Expected the logical shape of the linear operator's weight is equal to transposed physical shape" + dim_size = len(op_data.logical_shape) + transpose_partition_dim(sharding_spec, 0, dim_size - 1) return strategy @@ -54,6 +60,23 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha input_op_data = strategy.get_op_data_by_name(input_name) output_op_data = strategy.get_op_data_by_name(output_name) input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name) + + # recover the last logical dimension to physical dimension + last_logical_input_dims = len(input_op_data.logical_shape) - 1 + last_logical_output_dims = len(output_op_data.logical_shape) - 1 + last_physical_input_dims = input_op_data.data.dim() - 1 + last_physical_output_dims = output_op_data.data.dim() - 1 + + if last_logical_input_dims in input_sharding_spec.dim_partition_dict: + input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims} + else: + input_last_dim_mapping = {} + + if last_logical_output_dims in output_sharding_spec.dim_partition_dict: + output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims} + else: + output_last_dim_mapping = {} # get logger for debug message logger = get_dist_logger() @@ -73,14 +96,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) try: # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding + input_dim_mapping = {0: i} + input_dim_mapping.update(input_last_dim_mapping) + update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={0: i}, + dim_mapping=input_dim_mapping, physical_shape=input_op_data.data.shape, inplace=True) + output_dim_mapping = {0: i} + output_dim_mapping.update(output_last_dim_mapping) + update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping={0: i}, + dim_mapping=output_dim_mapping, physical_shape=output_op_data.data.shape, inplace=True) + strategy_copy.name = f'{strategy.name}_{i}' sharding_strategies.append(strategy_copy) except ShardingNotDivisibleError as e: logger.debug( @@ -95,12 +125,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) # after updating, the logical shape will be replaced by the physical shape + input_dim_mapping = {} + input_dim_mapping.update(input_last_dim_mapping) update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={}, + dim_mapping=input_dim_mapping, physical_shape=input_op_data.data.shape, inplace=True) + + output_dim_mapping = {} + output_dim_mapping.update(output_last_dim_mapping) update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping={}, + dim_mapping=output_dim_mapping, physical_shape=output_op_data.data.shape, inplace=True) sharding_strategies.append(strategy_copy) @@ -108,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha @operator_registry.register(torch.nn.Linear) -class LinearModuleHandler(ModuleHandler): +class LinearModuleHandler(MetaInfoModuleHandler): """ A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. """ @@ -116,7 +151,8 @@ class LinearModuleHandler(ModuleHandler): def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] - generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh)) + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -167,15 +203,16 @@ class LinearModuleHandler(ModuleHandler): @operator_registry.register(F.linear) -class LinearFunctionHandler(NodeHandler): +class LinearFunctionHandler(MetaInfoNodeHandler): """ - A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. + A LinearFunctionHandler which deals with the sharding strategies for F.Linear. """ def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] - generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh)) + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -198,27 +235,34 @@ class LinearFunctionHandler(NodeHandler): type=data_type, data=self.node.args[1]._meta_data, logical_shape=self.node.args[1]._meta_data.shape[::-1]) - physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.node._meta_data, + logical_shape=output_logical_shape, + ) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if self.node.args[2] is not None: + if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None: # check if the other operand is a parameter - if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter): + if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): data_type = OperationDataType.PARAM else: data_type = OperationDataType.ARG - physical_bias_operand = OperationData(name=str(self.node.args[2]), + physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), type=data_type, - data=self.node.args[2]._meta_data) + data=self.node.kwargs["bias"]._meta_data) mapping['bias'] = physical_bias_operand + return mapping def post_process(self, strategy: ShardingStrategy): # switch the dimensions of the transposed weight strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name=str(self.node.args[1])) - # create multiple sharding strategies for the inputs # as input can be multi-dimensinal and the partition dim is only 2D, # we need to map the partition at dim 0 to one of the first few dimensions of the input diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py new file mode 100644 index 000000000..d3f9fd01d --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py @@ -0,0 +1,486 @@ +import operator +from abc import ABC, abstractmethod +from copy import deepcopy +from enum import Enum +from functools import reduce +from typing import Dict, List, Union + +import torch + +from colossalai.auto_parallel.tensor_shard.utils.broadcast import ( + BroadcastType, + get_broadcast_dim_info, + get_broadcast_shape, +) +from colossalai.tensor.sharding_spec import ShardingSpecException + +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from ..utils import recover_sharding_spec_for_broadcast_shape +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import ( + BatchedMatMulStrategyGenerator, + DotProductStrategyGenerator, + LinearProjectionStrategyGenerator, + MatVecStrategyGenerator, + StrategyGenerator, +) + + +class MatMulType(Enum): + """ + The MatMulType is categorized into 4 types based on the reference of torch.matmul + in https://pytorch.org/docs/stable/generated/torch.matmul.html. + + DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements + MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D + MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D + BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D + """ + DOT = 0 + MM = 1 + MV = 2 + BMM = 3 + + +def get_matmul_type(input_dim: int, other_dim: int): + """ + Determine which type of matmul operation should be executed for the given tensor dimensions. + + Args: + input_dim (int): the number of dimensions for the input tenosr + other_dim (int): the number of dimensions for the other tenosr + """ + if input_dim == 1 and other_dim == 1: + matmul_type = MatMulType.DOT + elif input_dim in [1, 2] and other_dim == 2: + matmul_type = MatMulType.MM + elif input_dim == 2 and other_dim == 1: + matmul_type = MatMulType.MV + elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2): + matmul_type = MatMulType.BMM + else: + raise ValueError( + f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation" + ) + return matmul_type + + +class BmmTransform(ABC): + """ + BmmTransform is an abstraction of the shape conversion between logical and physical operation data + during the strategy generation. + """ + + @abstractmethod + def apply(self, shape_mapping: Dict[str, List[int]]): + pass + + @abstractmethod + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + pass + + +class Padder(BmmTransform): + """ + Add padding to the matrix dimensions for batched matrix multiplication. + """ + + def __init__(self) -> None: + # keep the padding dim, op_name -> padded_dim + self.padded_dim_mapping = {} + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = deepcopy(shape_mapping) + input_shape = mapping_copy['input'] + other_shape = mapping_copy['other'] + + if len(input_shape) == 1: + # if the input is a 1D tensor, 1 is prepended to its shape + # and it will be removed afterwards + input_shape.insert(0, 1) + self.padded_dim_mapping['input'] = -2 + self.padded_dim_mapping['output'] = -2 + elif len(other_shape) == 1: + # if the other is a 1D tensor, 1 is appended to its shape + # and it will be removed afterwards + other_shape = other_shape.append(1) + self.padded_dim_mapping['other'] = -1 + self.padded_dim_mapping['output'] = -1 + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + input_op_data = op_data_mapping['input'] + other_op_data = op_data_mapping['other'] + + def _remove_padded_dim(key, strategy): + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + tensor_shape = list(sharding_spec.entire_shape) + dim_partition_list = [None] * len(tensor_shape) + + # padded dim is a negative number as the padded dim must be a matrix dim + padded_dim = self.padded_dim_mapping[key] + + # compute the new dim partition + for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items(): + dim_partition_list[tensor_dim] = mesh_dims + dim_partition_list.pop(padded_dim) + unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None} + + # compute unpadded tensor shape + tensor_shape.pop(padded_dim) + + assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}' + + # update sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list) + + # enumerate all sharding strategies + strategies = [] + try: + strategy_copy = strategy.clone() + + # only one of input and other will be padded + if 'input' in self.padded_dim_mapping: + _remove_padded_dim('input', strategy_copy) + _remove_padded_dim('output', strategy_copy) + elif 'other' in self.padded_dim_mapping: + _remove_padded_dim('other', strategy_copy) + _remove_padded_dim('output', strategy_copy) + + strategies.append(strategy_copy) + except ShardingSpecException as e: + pass + return strategies + + +class Broadcaster(BmmTransform): + """ + Broadcast the non-matrix dimensions for batched matrix multiplication. + """ + + def __init__(self) -> None: + self.broadcast_dim_info = {} + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = shape_mapping.copy() + + # get shapes + input_shape = mapping_copy['input'] + other_shape = mapping_copy['other'] + + # sanity check + assert len(input_shape) > 1 and len(other_shape) > 1 + + # broadcast the batch dim and record + bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2]) + + # store the broadcast dim info + input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2]) + other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2]) + self.broadcast_dim_info['input'] = input_broadcast_dim_info + self.broadcast_dim_info['other'] = other_broadcast_dim_info + + # create the full logical shape + input_shape = bcast_non_matrix_dims + input_shape[-2:] + other_shape = bcast_non_matrix_dims + other_shape[-2:] + assert len(input_shape) == len(other_shape) + + mapping_copy['input'] = input_shape + mapping_copy['other'] = other_shape + + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + # remove sharding on the broadcast dim + def _remove_sharding_on_broadcast_dim(key, strategy): + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + tensor_shape = list(sharding_spec.entire_shape) + + for dim_idx, broadcast_type in self.broadcast_dim_info[key].items(): + if broadcast_type == BroadcastType.MULTIPLE: + # if the dim is originally 1 and multiplied during broadcast + # we set its sharding to R + # e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8] + # the dim 0 of [1, 2, 4] is multiplied to 4 + tensor_shape[dim_idx] = 1 + elif broadcast_type == BroadcastType.PADDDING: + # if the dim is padded + # we remove its sharding + tensor_shape[dim_idx] = None + + tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None] + + physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec=sharding_spec, + logical_shape=sharding_spec.entire_shape, + physical_shape=tensor_shape_before_broadcast) + strategy.sharding_specs[op_data] = physical_sharding_spec + + # enumerate all sharding strategies + strategies = [] + try: + strategy_copy = strategy.clone() + _remove_sharding_on_broadcast_dim('input', strategy_copy) + _remove_sharding_on_broadcast_dim('other', strategy_copy) + strategies.append(strategy_copy) + except ShardingSpecException as e: + pass + return strategies + + +class Viewer(BmmTransform): + """ + Change the shape of the tensor from N-D to 3D + """ + + def __init__(self) -> None: + self.batch_dims_before_view = None + + def apply(self, shape_mapping: Dict[str, List[int]]): + mapping_copy = shape_mapping.copy() + self.batch_dims_before_view = list(mapping_copy['input'][:-2]) + + # get shapes + input_shape = shape_mapping['input'] + other_shape = shape_mapping['other'] + + # view to 3d tensor + assert len(input_shape) >= 3 and len(other_shape) >= 3 + input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:] + other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:] + output_shape = input_shape[:2] + other_shape[2:] + mapping_copy['input'] = input_shape + mapping_copy['other'] = other_shape + mapping_copy['output'] = output_shape + return mapping_copy + + def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): + # get operation data + def _update_sharding_spec(key, strategy, physical_batch_dim): + """ + Map the logical batch dim to the physical batch dim + """ + op_data = op_data_mapping[key] + sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) + dim_partition_dict = sharding_spec.dim_partition_dict + entire_shape = sharding_spec.entire_shape + + # upddate the dimension index for the matrix dimensions + if 2 in dim_partition_dict: + dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2) + if 1 in dim_partition_dict: + dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1) + + # map the logical batch dim to phyiscal batch dim + if 0 in dim_partition_dict: + batch_dim_shard = dim_partition_dict.pop(0) + dim_partition_dict[physical_batch_dim] = batch_dim_shard + + # the new shape will be the batch dims + the last 2 matrix dims + shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:]) + sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict) + + num_batch_dim_before_view = len(self.batch_dims_before_view) + + # enumerate all sharding strategies + strategies = [] + for i in range(num_batch_dim_before_view): + # create a new strategy + strategy_copy = strategy.clone() + try: + _update_sharding_spec('input', strategy_copy, i) + _update_sharding_spec('other', strategy_copy, i) + _update_sharding_spec('output', strategy_copy, i) + strategies.append(strategy_copy) + except ShardingSpecException as e: + continue + return strategies + + +def _get_bmm_logical_shape(input_shape, other_shape, transforms): + """ + Compute the logical shapes for BMM operation. BMM has a general representation + [b, i, k] = [b, i, j] x [b, j, k] + + The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions + The logical shape for the bmm operands will undergo three stages + 1. append/prepend the 1 to the 1D tensor if there is any + 2. broadcast the non-matrix dimensions + 3. reshape to 3 dimensions + + """ + shape_mapping = {'input': input_shape, 'other': other_shape} + + for transform in transforms: + shape_mapping = transform.apply(shape_mapping) + + input_shape = shape_mapping.get('input', None) + other_shape = shape_mapping.get('other', None) + output_shape = shape_mapping.get('output', None) + + return input_shape, other_shape, output_shape + + +@operator_registry.register(torch.matmul) +@operator_registry.register(torch.Tensor.matmul) +class MatMulHandler(NodeHandler): + """ + The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation. + According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on + the operands. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # check which type of operation this matmul will call + self.input_meta_data = self.node.args[0]._meta_data + self.other_meta_data = self.node.args[1]._meta_data + self.output_meta_data = self.node._meta_data + + input_dim = self.input_meta_data.dim() + other_dim = self.other_meta_data.dim() + self.matmul_type = get_matmul_type(input_dim, other_dim) + + if self.matmul_type == MatMulType.BMM: + # bmm operation can possibly involve padding, broadcasting and view + # these transforms will be used to create logical shape and + # recover physical sharding spec + self.transforms = [Padder(), Broadcaster(), Viewer()] + else: + self.transforms = None + + def get_strategy_generator(self) -> List[StrategyGenerator]: + generators = [] + op_data_mapping = self.get_operation_data_mapping() + if self.matmul_type == MatMulType.BMM: + generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.DOT: + generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.MV: + generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh)) + elif self.matmul_type == MatMulType.MM: + generators.append( + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + logical_shape_func = { + MatMulType.DOT: self._get_logical_shape_for_dot, + MatMulType.MM: self._get_logical_shape_for_mm, + MatMulType.MV: self._get_logical_shape_for_mv, + MatMulType.BMM: self._get_logical_shape_for_bmm + } + logical_shapes = logical_shape_func[self.matmul_type]() + op_data_mapping = self._get_op_data_mapping(*logical_shapes) + return op_data_mapping + + def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape): + # convert list to torch.Size + if input_logical_shape: + input_logical_shape = torch.Size(input_logical_shape) + + if other_logical_shape: + other_logical_shape = torch.Size(other_logical_shape) + + if output_logical_shape: + output_logical_shape = torch.Size(output_logical_shape) + + # create op data + input_op_data = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.input_meta_data, + logical_shape=input_logical_shape) + other_op_data = OperationData(name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.other_meta_data, + logical_shape=other_logical_shape) + output_op_data = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.output_meta_data, + logical_shape=output_logical_shape) + + mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + return mapping + + def _get_logical_shape_for_dot(self): + """ + The operands for the dot operation have the same logical shape as the physical shape + """ + return None, None, None + + def _get_logical_shape_for_mm(self): + """ + We need to handle the input tensor for a matrix-matrix multiplcation as the input + tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape + (e.g. [4] -> [1, 4]). + """ + if self.input_meta_data.dim() == 1: + input_logical_shape = [1] + list(self.input_meta_data.shape) + input_logical_shape = torch.Size(input_logical_shape) + else: + input_logical_shape = None + return input_logical_shape, None, None + + def _get_logical_shape_for_mv(self): + """ + No broadcasting or dim insertion occurs for matrix-vector operation. + """ + return None, None, None + + def _get_logical_shape_for_bmm(self): + input_physical_shape = list(self.input_meta_data.shape) + other_physical_shape = list(self.other_meta_data.shape) + return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms) + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + if self.matmul_type in [MatMulType.DOT, MatMulType.MV]: + return strategy + elif self.matmul_type == MatMulType.MM: + if self.input_meta_data.dim() == 1: + # if a 1 is prepended to the input shape (this occurs when input is a 1D tensor) + # we need to remove that dim + input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0])) + input_physical_shape = self.node.args[0]._meta_data.shape + dim_partition_dict = input_sharding_spec.dim_partition_dict + + # remove the partitioning in the dim 0 + if 0 in dim_partition_dict: + dim_partition_dict.pop(0, None) + + # move the partitioning in dim 1 to dim 0 + if -1 in dim_partition_dict: + shard = dim_partition_dict.pop(-1) + dim_partition_dict[0] = shard + if 1 in dim_partition_dict: + shard = dim_partition_dict.pop(1) + dim_partition_dict[0] = shard + + # re-init the sharding spec + input_sharding_spec.__init__(input_sharding_spec.device_mesh, + entire_shape=input_physical_shape, + dim_partition_dict=dim_partition_dict) + return strategy + else: + return strategy + elif self.matmul_type == MatMulType.BMM: + op_data_mapping = self.get_operation_data_mapping() + + strategies = [strategy] + # recover the physical sharding spec + for transform in self.transforms[::-1]: + recovered_stragies = [] + for strategy_ in strategies: + output = transform.recover(op_data_mapping, strategy_) + if isinstance(output, ShardingStrategy): + recovered_stragies.append(output) + elif isinstance(output, (list, tuple)): + recovered_stragies.extend(output) + else: + raise TypeError( + f"Found unexpected output type {type(output)} from the recover method of BmmTransform") + strategies = recovered_stragies + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 8d9683766..78dc58c90 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -1,11 +1,14 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Union +from typing import Dict, List, Tuple, Union import torch from torch.fx.node import Node +from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, + OperationDataType, + ShardingSpec, ShardingStrategy, StrategiesVector, TrainCycleItem, @@ -49,7 +52,16 @@ class NodeHandler(ABC): for node in self.predecessor_node: node_name = str(node) + # get the current sharding spec generated by this node handler + # we will not compute the resharding costs for the node not counted in the strategy. + # And the node with tuple or list output need to be handled below. + node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()] + if str(node) not in node_in_strategy: + continue + + op_data = strategy.get_op_data_by_name(node_name) + current_sharding_spec = strategy.sharding_specs[op_data] # get the sharding specs for this node generated # in its own node handler assert hasattr(node, 'strategies_vector'), \ @@ -59,27 +71,83 @@ class NodeHandler(ABC): prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector ] - # get the current sharding spec generated by this node handler - op_data = strategy.get_op_data_by_name(node_name) - current_sharding_spec = strategy.sharding_specs[op_data] - # create data structrure to store costs - if op_data not in resharding_costs: + if node not in resharding_costs: resharding_costs[node] = [] + def _compute_resharding_cost( + prev_sharding_spec: Union[ShardingSpec, + List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec, + List[ShardingSpec]], + data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem: + """ + This is a helper function to compute the resharding cost for a specific strategy of a node. + """ + if prev_sharding_spec is None: + return TrainCycleItem(fwd=0, bwd=0, total=0) + elif isinstance(prev_sharding_spec, ShardingSpec): + if isinstance(data, torch.Tensor): + dtype = data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + _, _, consistency_cost = shape_consistency_manager.shape_consistency( + prev_sharding_spec, current_sharding_spec) + + resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes, + bwd=consistency_cost["backward"] * size_per_elem_bytes, + total=consistency_cost["total"] * size_per_elem_bytes) + return resharding_cost + else: + # This raise is used to check if we have missed any type of data. + # It could be merged into Parameter branch, which means we won't handle + # non-tensor arguments. + raise ValueError(f'Unsupported data type {type(data)}') + else: + assert isinstance(prev_sharding_spec, (tuple, list)), \ + f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ + or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}' + + fwd_cost = 0 + bwd_cost = 0 + total_cost = 0 + for index, (prev_sharding_spec_item, + current_sharding_spec_item) in enumerate(zip(prev_sharding_spec, + current_sharding_spec)): + item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item, + data[index]) + fwd_cost += item_cost.fwd + bwd_cost += item_cost.bwd + total_cost += item_cost.total + resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost) + return resharding_cost + # for each sharding spec generated by the predecessor's node handler # compute the resharding cost to switch to the sharding spec generated # by the current node handler for prev_sharding_spec in prev_sharding_specs: - _, _, resharding_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec, - current_sharding_spec) - resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"], - bwd=resharding_cost["backward"], - total=resharding_cost["total"]) + resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data) resharding_costs[node].append(resharding_cost) strategy.resharding_costs = resharding_costs return strategy + def get_target_function(self) -> callable: + """ + This function is used to get the target function for the node handler. + The target function is used to analyze the costs of strategies. + """ + if self.node.op in ('placeholder', 'get_attr', 'output'): + return None + + if self.node.op == 'call_module': + target = self.node.graph.owning_module.get_submodule(self.node.target) + elif self.node.op == 'call_function': + target = self.node.target + elif self.node.op == 'call_method': + target = getattr(self.node.args[0]._meta_data.__class__, self.node.target) + else: + raise ValueError(f'Unsupported node type: {self.node.op}') + + return target + def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: """ Register different sharding strategies for the current node. @@ -151,6 +219,38 @@ class NodeHandler(ABC): pass +class MetaInfoNodeHandler(NodeHandler): + """ + This is a base class to handle the nodes patched in the meta profiler. + + Note: this class will be integrated into the NodeHandler class in the future, after + all the functions are patched. + """ + + def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: + """ + This method is inherited from NodeHandler. It will register the strategies first, + and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class. + """ + super().register_strategy(compute_resharding_cost=compute_resharding_cost) + target = self.get_target_function() + # Currently we haven't patched all the torch functions and modules, so if the target + # is not patched, we will use the default cost model to compute the cost. + # TODO: patch all torch functions and modules to make it clean + if meta_register.has(target.__class__) or meta_register.has(target): + metainfo_vector = [] + for strategy in self.strategies_vector: + metainfo = MetaInfo(strategy, target) + strategy.compute_cost = metainfo.compute_cost + strategy.memory_cost = metainfo.memory_cost + metainfo_vector.append(metainfo) + + # attach metainfos to the handler + setattr(self, "metainfo_vector", metainfo_vector) + + return self.strategies_vector + + class ModuleHandler(NodeHandler): def __init__(self, *args, **kwargs) -> None: @@ -168,3 +268,35 @@ class ModuleHandler(NodeHandler): self.module = module self.named_parameters = named_parameters self.named_buffers = named_buffers + + +class MetaInfoModuleHandler(ModuleHandler): + """ + This is a base class to handle the module patched in the meta profiler. + + Note: this class will be integrated into the ModuleHandler class in the future, after + all the modules are patched. + """ + + def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: + """ + This method is inherited from NodeHandler. It will register the strategies first, + and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class. + """ + super().register_strategy(compute_resharding_cost=compute_resharding_cost) + target = self.get_target_function() + # Currently we haven't patched all the torch functions and modules, so if the target + # is not patched, we will use the default cost model to compute the cost. + # TODO: patch all torch functions and modules to make it clean + if meta_register.has(target.__class__) or meta_register.has(target): + metainfo_vector = [] + for strategy in self.strategies_vector: + metainfo = MetaInfo(strategy, target) + strategy.compute_cost = metainfo.compute_cost + strategy.memory_cost = metainfo.memory_cost + metainfo_vector.append(metainfo) + + # attach metainfos to the handler + setattr(self, "metainfo_vector", metainfo_vector) + + return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py index 1509c05a3..4e71ccba9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py @@ -3,7 +3,7 @@ from typing import Dict, List import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import ModuleHandler +from .node_handler import MetaInfoModuleHandler, ModuleHandler from .registry import operator_registry from .strategy import NormalPoolStrategyGenerator, StrategyGenerator @@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler'] @operator_registry.register(torch.nn.AvgPool1d) @operator_registry.register(torch.nn.AvgPool2d) @operator_registry.register(torch.nn.AvgPool3d) -class NormPoolingHandler(ModuleHandler): +class NormPoolingHandler(MetaInfoModuleHandler): """ A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module. """ diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py index 489e40daf..ed120a8c3 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py @@ -2,38 +2,51 @@ from typing import Dict, List import torch -from ..sharding_strategy import OperationData, OperationDataType +from colossalai.device.device_mesh import DeviceMesh + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector from .node_handler import NodeHandler from .strategy import OutputGenerator, StrategyGenerator -__all__ = ['OuputHandler'] +__all__ = ['OutputHandler'] -class OuputHandler(NodeHandler): +class OutputHandler(NodeHandler): """ - A OuputHandler which deals with the sharding strategies for Output Node. + A OutputHandler which deals with the sharding strategies for Output Node. """ + def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, + output_option: str) -> None: + super().__init__(node, device_mesh, strategies_vector) + self.output_option = output_option + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] - generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node)) + generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option)) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - dummy_output = torch.empty(1,).to("meta") - physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=dummy_output) - - mapping = {"output": physical_output} + mapping = {} + output_meta_data = [] for index, input_node in enumerate(self.predecessor_node): - if not hasattr(input_node, "_meta_data"): - print(input_node.name) - physical_inputs = OperationData(name=str(input_node), - type=OperationDataType.ARG, - data=input_node._meta_data) + input_meta_data = input_node._meta_data + physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data) name_key = f'input_{index}' mapping[name_key] = physical_inputs + output_meta_data.append(input_meta_data) + assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.' + if len(output_meta_data) == 1: + output_meta_data = output_meta_data[0] + else: + output_meta_data = tuple(output_meta_data) + + self.node._meta_data = output_meta_data + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping["output"] = physical_output return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py index 88a02428e..e4f40fc93 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py @@ -1,21 +1,31 @@ from typing import Dict, List -from ..sharding_strategy import OperationData, OperationDataType +from torch.fx.node import Node + +from colossalai.device.device_mesh import DeviceMesh + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector from .node_handler import NodeHandler from .strategy import PlaceholderGenerator, StrategyGenerator -__all__ = ['PlacehodlerHandler'] +__all__ = ['PlaceholderHandler'] -class PlacehodlerHandler(NodeHandler): +class PlaceholderHandler(NodeHandler): """ - A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node. + A PlaceholderHandler which deals with the sharding strategies for Placeholder Node. """ + def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, + placeholder_option: str) -> None: + super().__init__(node, device_mesh, strategies_vector) + self.placeholder_option = placeholder_option + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] - generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh)) + generators.append( + PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py index 6bed842d4..8e06cec4f 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py @@ -8,7 +8,12 @@ class Registry: def register(self, source): def wrapper(func): - self.store[source] = func + if isinstance(source, (list, tuple)): + # support register a list of items for this func + for element in source: + self.store[element] = func + else: + self.store[source] = func return func return wrapper diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py index 402485352..7763b1884 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py @@ -3,18 +3,17 @@ from typing import Dict, List import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import ReshapeGenerator, StrategyGenerator __all__ = ['ReshapeHandler'] -@operator_registry.register(torch.reshape) @operator_registry.register(torch.flatten) -@operator_registry.register(torch.Tensor.permute) +@operator_registry.register(torch.Tensor.unsqueeze) @operator_registry.register(torch.nn.AdaptiveAvgPool2d) -class ReshapeHandler(NodeHandler): +class ReshapeHandler(MetaInfoNodeHandler): """ A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. """ @@ -25,13 +24,47 @@ class ReshapeHandler(NodeHandler): generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) return generators + def infer_logical_shape(self, data): + """ + This function is used to infer logical shape for operands. + + Notes: This function is only used for the operands whose data are not only in type of tensor, + such as tuple of tensor. + """ + if isinstance(data, torch.Tensor): + return data.shape + else: + assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor." + logical_shape = [] + for tensor in data: + assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor." + logical_shape.append(tensor.shape) + logical_shape = tuple(logical_shape) + return logical_shape + def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process + + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + input_logical_shape = self.infer_logical_shape(input_data) physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + type=data_type, + data=input_data, + logical_shape=input_logical_shape) + + output_data = self.node._meta_data + output_logical_shape = self.infer_logical_shape(output_data) + physical_output = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_data, + logical_shape=output_logical_shape) mapping = {"input": physical_input_operand, "output": physical_output} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py new file mode 100644 index 000000000..743a1f90e --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py @@ -0,0 +1,55 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import SoftmaxGenerator, StrategyGenerator + +__all__ = ['SoftmaxHandler'] + + +@operator_registry.register(torch.nn.Softmax) +@operator_registry.register(torch.nn.functional.softmax) +class SoftmaxHandler(NodeHandler): + """ + A SoftmaxHandler which deals with the sharding strategies for + torch.nn.Softmax or torch.nn.functional.softmax. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + softmax_dim = self.node.kwargs['dim'] + + num_dims = self.node.args[0]._meta_data.dim() + # recover negative value to positive + if softmax_dim < 0: + softmax_dim += num_dims + + physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "softmax_dim": physical_dim_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py index f137f09db..8d25475f9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py @@ -1,14 +1,24 @@ from .batch_norm_generator import BatchNormStrategyGenerator +from .binary_elementwise_generator import BinaryElementwiseStrategyGenerator from .conv_strategy_generator import ConvStrategyGenerator -from .getitem_generator import (GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator) +from .embedding_generator import EmbeddingStrategyGenerator +from .getattr_generator import GetattrGenerator +from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator from .layer_norm_generator import LayerNormGenerator -from .matmul_strategy_generator import (BatchedMatMulStrategyGenerator, DotProductStrategyGenerator, - LinearProjectionStrategyGenerator, MatVecStrategyGenerator) +from .matmul_strategy_generator import ( + BatchedMatMulStrategyGenerator, + DotProductStrategyGenerator, + LinearProjectionStrategyGenerator, + MatVecStrategyGenerator, +) from .normal_pooling_generator import NormalPoolStrategyGenerator from .output_generator import OutputGenerator from .placeholder_generator import PlaceholderGenerator from .reshape_generator import ReshapeGenerator +from .softmax_generator import SoftmaxGenerator from .strategy_generator import StrategyGenerator +from .sum_generator import SumGenerator +from .tensor_constructor_generator import TensorConstructorGenerator from .unary_elementwise_generator import UnaryElementwiseGenerator from .where_generator import WhereGenerator @@ -17,5 +27,6 @@ __all__ = [ 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', - 'ReshapeGenerator', 'NormalPoolStrategyGenerator' + 'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', + 'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py index e648fff39..1f3812429 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -3,7 +3,13 @@ import operator from functools import reduce from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -98,6 +104,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost + @ignore_sharding_exception def split_input_channel(self, mesh_dim_0): name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' dim_partition_dict_mapping = { @@ -129,6 +136,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' dim_partition_dict_mapping = { @@ -160,6 +168,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def non_split(self): name = f'RR = RR x R' dim_partition_dict_mapping = { @@ -181,6 +190,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' dim_partition_dict_mapping = { @@ -204,17 +214,21 @@ class BatchNormStrategyGenerator(StrategyGenerator): # For SyncBN case, we don't need to do communication for weight and bias. # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # to SyncBN operation instead of inserting a communication node. - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_0) + logical_process_axis=mesh_dim_0, + comm_type=CommType.IMPLICIT) - communication_action_mapping = {"output": output_comm_spec} + # TODO: Temporary solution has no communication cost, + # above action should be added after the SyncBN replace pass completed. + communication_action_mapping = {} return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' dim_partition_dict_mapping = { @@ -238,17 +252,21 @@ class BatchNormStrategyGenerator(StrategyGenerator): # For SyncBN case, we don't need to do communication for gradients of weight and bias. # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # to SyncBN operation instead of inserting a communication node. - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.IMPLICIT) - communication_action_mapping = {"output": output_comm_spec} + # TODO: Temporary solution has no communication cost, + # above action should be added after the SyncBN replace pass completed. + communication_action_mapping = {} return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' dim_partition_dict_mapping = { @@ -282,12 +300,15 @@ class BatchNormStrategyGenerator(StrategyGenerator): # For SyncBN case, we don't need to do communication for gradients of weight and bias. # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # to SyncBN operation instead of inserting a communication node. - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=[mesh_dim_0]) + logical_process_axis=[mesh_dim_0], + comm_type=CommType.IMPLICIT) - communication_action_mapping = {"output": output_comm_spec} + # TODO: Temporary solution has no communication cost, + # above action should be added after the SyncBN replace pass completed. + communication_action_mapping = {} return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -316,14 +337,14 @@ class BatchNormStrategyGenerator(StrategyGenerator): # TODO: The strategies below should be uncommented after runtime # passes ready. # SR = SR x R WITH SYNC_BN - # strategy_list.append(self.split_input_batch(0)) - # strategy_list.append(self.split_input_batch(1)) + strategy_list.append(self.split_input_batch(0)) + strategy_list.append(self.split_input_batch(1)) # SS = SS x S WITH SYNC_BN - # strategy_list.append(self.split_input_both_dim(0, 1)) - # strategy_list.append(self.split_input_both_dim(1, 0)) + strategy_list.append(self.split_input_both_dim(0, 1)) + strategy_list.append(self.split_input_both_dim(1, 0)) # S01R = S01R x R WITH SYNC_BN - # strategy_list.append(self.split_input_batch_1d(0, 1)) + strategy_list.append(self.split_input_batch_1d(0, 1)) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py new file mode 100644 index 000000000..fd7f811c8 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py @@ -0,0 +1,111 @@ +import operator +from functools import reduce +from typing import List + +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) +from colossalai.tensor.sharding_spec import ShardingSpecException + +from .strategy_generator import StrategyGenerator + +__all__ = ['BinaryElementwiseStrategyGenerator'] + + +class BinaryElementwiseStrategyGenerator(StrategyGenerator): + """ + An BinaryElementwiseStrategyGenerator is a node handler which deals with elementwise operations + which have two operands and broadcasting occurs such as torch.add. + + The logical shape for this operation will be `input other`. + """ + + def validate(self) -> bool: + assert len(self.op_data) == 3, \ + f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}' + for name, op_data in self.op_data.items(): + if not isinstance(op_data.data, (torch.Tensor, int, float)): + raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.') + + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + + # since elementwise ops are not compute-intensive, + # we approximate the backward compute cost + # to be twice the fwd compute cost + fwd_compute_cost = reduce(operator.mul, shape) + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + # all input, output and outputs have the same shape + shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + + # compute fwd memory cost in bytes + # as the elementwise ops are not memory-intensive + # we approximate the fwd memroy cost to be the output + # and the backward memory cost to be grad of input and other + input_bytes = self._compute_size_in_bytes(strategy, 'input') + other_bytes = self._compute_size_in_bytes(strategy, 'other') + output_bytes = self._compute_size_in_bytes(strategy, 'output') + fwd_memory_cost = MemoryCost(activation=output_bytes) + bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes) + total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes) + memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): + # we check for the output logical shape to get the number of dimensions + dim_partition_list = [] + dim_size = len(self.op_data['output'].logical_shape) + + # enumerate all the 2D sharding cases + sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_2d) + + # enumerate all the 1D sharding cases + sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_0) + sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_1) + + # add empty dict for fully replicated case + dim_partition_list.append({}) + + # sharding strategy bookkeeping + strategy_list = [] + + # convert these dim partition dict to sharding strategy + for dim_partition_dict in dim_partition_list: + dim_partition_dict_mapping = dict(input=dim_partition_dict, + other=dim_partition_dict, + output=dim_partition_dict) + + try: + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + communication_action_mapping = {} + + # get name + sharding_seq = sharding_spec_mapping['input'].sharding_sequence + name = f'{sharding_seq} = {sharding_seq} {sharding_seq}' + sharding_strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(sharding_strategy) + except ShardingSpecException: + continue + return strategy_list + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = self.enumerate_all_possible_output(0, 1) + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py index 83476e4fe..c2154b310 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py @@ -4,7 +4,6 @@ import warnings from functools import reduce from typing import List - from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommAction, CommType, @@ -12,10 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ShardingStrategy, TrainCycleItem, ) - -from colossalai.auto_parallel.tensor_shard.utils import \ - ignore_sharding_exception - +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -135,7 +131,8 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.BEFORE) + comm_type=CommType.BEFORE, + arg_index=0) communication_action_mapping = {"input": input_comm_action} if self.is_param("other"): @@ -144,14 +141,31 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.HOOK) - communication_action_mapping["other"] = other_comm_action - if self.has_bias and self.is_param("bias"): - bias_comm_action = self.get_communication_action( - sharding_spec_mapping["bias"], + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') communication_action_mapping["bias"] = bias_comm_action return self.get_sharding_strategy(name=name, @@ -183,14 +197,31 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.HOOK) - communication_action_mapping["other"] = other_comm_action - if self.has_bias and self.is_param("bias"): - bias_comm_action = self.get_communication_action( - sharding_spec_mapping["bias"], + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') communication_action_mapping["bias"] = bias_comm_action return self.get_sharding_strategy(name=name, @@ -223,8 +254,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER, - arg_index=0) + comm_type=CommType.AFTER) communication_action_mapping = {"output": output_comm_action} @@ -234,14 +264,29 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.HOOK) - communication_action_mapping["other"] = other_comm_action - if self.has_bias and self.is_param("bias"): - bias_comm_action = self.get_communication_action( - sharding_spec_mapping["bias"], + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping["other"] = other_comm_action + if self.has_bias: + if self.is_param("bias"): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') communication_action_mapping["bias"] = bias_comm_action return self.get_sharding_strategy(name=name, @@ -277,12 +322,11 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER, - arg_index=0) + comm_type=CommType.AFTER) input_comm_action = self.get_communication_action( sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0, + logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, arg_index=0) @@ -316,8 +360,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER, - arg_index=0) + comm_type=CommType.AFTER) communication_action_mapping = {"output": output_comm_action} @@ -351,7 +394,8 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.BEFORE) + comm_type=CommType.BEFORE, + arg_index=0) communication_action_mapping = {"input": input_comm_action} @@ -404,14 +448,30 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.HOOK) - communication_action_mapping["other"] = other_comm_action - - if self.has_bias and self.is_param("bias"): - bias_comm_action = self.get_communication_action( - sharding_spec_mapping["bias"], + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + if self.has_bias: + if self.is_param("bias"): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + key_for_kwarg='bias') communication_action_mapping["bias"] = bias_comm_action return self.get_sharding_strategy(name=name, @@ -441,8 +501,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.AFTER, - arg_index=0) + comm_type=CommType.AFTER) communication_action_mapping = {"output": output_comm_action} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py new file mode 100644 index 000000000..82a04ab52 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py @@ -0,0 +1,310 @@ +import copy +import operator +import warnings +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +from .strategy_generator import StrategyGenerator + + +class EmbeddingStrategyGenerator(StrategyGenerator): + """ + EmbeddingStrategyGenerator is a generic class to generate strategies for nn.Embedding or F.embedding. + The operation data is defined as `output = input x other`. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + + Note: The computation cost for the embedding handler is estimated as dense computing now. + It may not be accurate. + ''' + # TODO: estimate the embedding computation cost as sparse operation + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + + input_size_product = reduce(operator.mul, sharded_input_shape) + other_size_product = reduce(operator.mul, sharded_other_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + forward_compute_cost = input_size_product * other_size_product + + backward_activation_cost = other_size_product * output_size_product / sharded_output_shape[-1] + backward_weight_cost = input_size_product * other_size_product + backward_compute_cost = backward_weight_cost + backward_activation_cost + + total_compute_cost = forward_compute_cost + backward_compute_cost + + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'other': self._compute_size_in_bytes(strategy, "other"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + other + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + other_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def non_split(self): + name = f'RR = R x RR' + + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + @ignore_sharding_exception + def split_input(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": {}, + "output": { + 0: [mesh_dim_0], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + }, + "other": { + 1: [mesh_dim_1], + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping = {"input": input_comm_action} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1] + }, + "other": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + communication_action_mapping = {} + + if self.is_param("other"): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping["other"] = other_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_embedding_dim(self, mesh_dim_0): + name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}' + + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0], + }, + "output": { + 1: [mesh_dim_0], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0, mesh_dim_1], + }, + "output": { + 1: [mesh_dim_0, mesh_dim_1], + }, + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_action = self.get_communication_action( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + + communication_action_mapping = {"input": input_comm_action} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def collate_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + # RR= R x RR + strategies.append(self.non_split()) + + # SR = S x RR + strategies.append(self.split_input(0)) + strategies.append(self.split_input(1)) + + # SS = S x RS + strategies.append(self.split_input_and_embedding_dim(0, 1)) + strategies.append(self.split_input_and_embedding_dim(1, 0)) + + # S01R = S01 x RR + strategies.append(self.split_1d_parallel_on_input(0, 1)) + + # RS = R x RS + strategies.append(self.split_embedding_dim(0)) + strategies.append(self.split_embedding_dim(1)) + + # RS01 = R x RS01 + strategies.append(self.split_1d_parallel_on_embedding_dim(0, 1)) + + return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py new file mode 100644 index 000000000..bbeb9a639 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py @@ -0,0 +1,89 @@ +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) +from colossalai.tensor.sharding_spec import ShardingSpecException + +from .strategy_generator import StrategyGenerator + +__all__ = ['GetattrGenerator'] + + +class GetattrGenerator(StrategyGenerator): + """ + PlaceholderGenerator is a generic class to generate strategies for placeholder node. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + + # compute fwd cost incurred + # fwd_cost = output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + @ignore_sharding_exception + def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): + # we check for the output logical shape to get the number of dimensions + dim_partition_list = [] + dim_size = len(self.op_data['output'].logical_shape) + + # enumerate all the 2D sharding cases + sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_2d) + + # enumerate all the 1D sharding cases + sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_0) + sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size) + dim_partition_list.extend(sharding_list_1d_on_dim_1) + + # add empty dict for fully replicated case + dim_partition_list.append({}) + + # sharding strategy bookkeeping + strategy_list = [] + + # convert these dim partition dict to sharding strategy + for dim_partition_dict in dim_partition_list: + dim_partition_dict_mapping = dict(output=dim_partition_dict) + + try: + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + communication_action_mapping = {} + + # get name + name = f"get_attr {sharding_spec_mapping['output'].sharding_sequence}" + sharding_strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(sharding_strategy) + except ShardingSpecException: + continue + + return strategy_list + + def collate_strategies(self) -> List[ShardingStrategy]: + return self.enumerate_all_possible_output(0, 1) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py index 8b8080b75..0aeb2e0d4 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -1,8 +1,15 @@ import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.logging import get_dist_logger from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpecException from .strategy_generator import FollowingStrategyGenerator @@ -64,37 +71,61 @@ class TensorStrategyGenerator(GetItemStrategyGenerator): def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] - for strategy in self.predecessor_node.strategies_vector: - dim_partition_dict_mapping = {} - communication_action_mapping = {} - dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict - dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) - gather_input = 0 in dim_partition_dict_for_input - if gather_input: - logical_process_axis = dim_partition_dict_for_output.pop(0) + getitem_index = self.op_data['index'].data + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + try: + logger = get_dist_logger() + dim_partition_dict_mapping = {} + communication_action_mapping = {} + dim_partition_dict_for_input = copy.deepcopy( + strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict) - shift_dim_partition_dict_for_output = {} - for dim, mesh_dim_list in dim_partition_dict_for_output.items(): - shift_dim_partition_dict_for_output[dim - 1] = mesh_dim_list - dim_partition_dict_for_output = shift_dim_partition_dict_for_output - dim_partition_dict_mapping = { - "input": dim_partition_dict_for_input, - "output": dim_partition_dict_for_output, - } - sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - if gather_input: - input_communication_spec = self.get_communication_spec( - sharding_spec_mapping["input"], - communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - logical_process_axis=logical_process_axis) - communication_action_mapping["input"] = input_communication_spec + int_index = False + if isinstance(getitem_index, int): + int_index = True + getitem_dims = [ + 0, + ] + shift_length = 1 + elif isinstance(getitem_index, slice): + getitem_dims = [ + 0, + ] + else: + getitem_dims = [i for i in range(len(getitem_index))] + if isinstance(getitem_index[0], int): + int_index = True + shift_length = len(getitem_index) - name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + gather_dims = [] + for dim in getitem_dims: + if dim in dim_partition_dict_for_input: + gather_dims.append(dim) - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + for dim in gather_dims: + dim_partition_dict_for_input.pop(dim) + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + if int_index: + shift_dim_partition_dict_for_output = {} + for dim, mesh_dim_list in dim_partition_dict_for_output.items(): + shift_dim_partition_dict_for_output[dim - shift_length] = mesh_dim_list + dim_partition_dict_for_output = shift_dim_partition_dict_for_output + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + except ShardingSpecException as e: + logger.debug(e) + continue strategy_list.append(strategy) for strategy in strategy_list: @@ -114,7 +145,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator): strategy_list = [] index = self.op_data["index"].data - for strategy in self.predecessor_node.strategies_vector: + for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector): # the sharding spec for input in this case is a tuple of ShardingSpec. sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict @@ -125,8 +156,11 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator): } sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) sharding_spec_mapping["input"] = sharding_spec_for_input - - name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + input_sharding_info = f"get the {index} element from (" + for sharding_spec in sharding_spec_for_input: + input_sharding_info += f'{sharding_spec.sharding_sequence}, ' + input_sharding_info += ")" + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}' strategy = self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py index 8c7d11437..fbb6070f7 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py @@ -3,9 +3,17 @@ import operator from functools import reduce from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) -from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, - enumerate_all_possible_2d_sharding) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -87,6 +95,7 @@ class LayerNormGenerator(StrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost + @ignore_sharding_exception def _generate_strategy_with_dim_partition(self, dim_partition): dim_partition_dict_mapping = { "input": dim_partition, @@ -107,18 +116,20 @@ class LayerNormGenerator(StrategyGenerator): total_mesh_dim_list = total_mesh_dim_list[0] communication_action_mapping = {} - other_comm_spec = self.get_communication_spec( + other_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=total_mesh_dim_list) - communication_action_mapping["other"] = other_comm_spec + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.HOOK) + communication_action_mapping["other"] = other_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=total_mesh_dim_list) - communication_action_mapping["bias"] = bias_comm_spec + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.HOOK) + communication_action_mapping["bias"] = bias_comm_action strategy = self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -142,6 +153,7 @@ class LayerNormGenerator(StrategyGenerator): strategy_list.append(strategy) return strategy_list + @ignore_sharding_exception def non_split(self): name = f'RR = RR x R' dim_partition_dict_mapping = { diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index be2a95098..fa2246f95 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -1,8 +1,14 @@ import operator +from ast import arg from functools import reduce from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception from colossalai.tensor.shape_consistency import CollectiveCommPattern @@ -54,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() fwd_compute_cost = sharded_input_shape[0] - bwd_compute_cost = sharded_input_shape * 2 + bwd_compute_cost = fwd_compute_cost * 2 compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) return compute_cost + @ignore_sharding_exception def no_split(self): name = f'R = R dot R' dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}} @@ -69,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_one_dim(self, mesh_dim): name = f'R = S{mesh_dim} dot S{mesh_dim}' @@ -77,16 +85,17 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication action - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim) - communication_action_mapping = {"output": output_comm_spec} + logical_process_axis=mesh_dim, + comm_type=CommType.AFTER) + communication_action_mapping = {"output": output_comm_action} return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy]: + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] # do not split dimensions for dot product @@ -106,38 +115,86 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): def validate(self) -> bool: input_op_data = self.op_data['input'] other_op_data = self.op_data['other'] - assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1 + assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1 + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + fwd_compute_cost = sharded_input_shape[0] + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + return compute_cost + + @ignore_sharding_exception def no_split(self): name = "R = R x R" - dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}} + dim_partition_dict = {"input": {}, "other": {}, "output": {}} + + if self.has_bias: + dim_partition_dict['bias'] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}) + @ignore_sharding_exception def split_input_batch(self, mesh_dim): name = f'S{mesh_dim}R = S{mesh_dim}R x R' # get sharding spec - dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}} + dim_partition_dict = { + "input": { + 0: [mesh_dim] + }, + "other": {}, + "output": { + 0: [mesh_dim] + }, + } + + if self.has_bias: + dim_partition_dict['bias'] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication action - other_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['other'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim) - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['bias'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim) - communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec} + communication_action_mapping = {} + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action + + if self.has_bias: + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=2) + communication_action_mapping['bias'] = bias_comm_action + return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy]: + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] # no split @@ -152,6 +209,10 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): + def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'): + super().__init__(operation_data_mapping, device_mesh) + self.linear_projection_type = linear_projection_type + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # C = AB # C: [M, N], A: [M, P], B: [P, N] @@ -202,6 +263,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # RS01 = RR x RS01 strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) + # RR = RR x RR + strategies.append(self.non_split()) + return strategies @ignore_sharding_exception @@ -215,36 +279,66 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): "other": { -1: [mesh_dim_1] }, - "bias": { - -1: [mesh_dim_1] - }, "output": { 0: [mesh_dim_0], -1: [mesh_dim_1] }, } + + # linear bias only has one dimension, but addmm bias has same dimensions + # as the output logically. + if self.linear_projection_type == 'linear': + dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]} + elif self.linear_projection_type == 'addmm': + dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]} + else: + raise ('Unsupported linear projection type') + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # set communication action communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_1) - other_comm_spec = self.get_communication_spec( - sharding_spec_mapping["output"], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) - communication_action_mapping['input'] = input_comm_spec - communication_action_mapping['other'] = other_comm_spec - - if self.has_bias: - bias_comm_spec = self.get_communication_spec( - sharding_spec_mapping["bias"], + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) + + communication_action_mapping['input'] = input_comm_action + communication_action_mapping['other'] = other_comm_action + + # we only add allreduce comm action for linear bias, because + # allreduce comm action for addmm bias will be considered in post processing + if self.has_bias and self.linear_projection_type == 'linear': + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -269,28 +363,61 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): 0: [mesh_dim_0] }, } + + # linear bias only has one dimension, but addmm bias has same dimensions + # as the output logically. + if self.linear_projection_type == 'linear': + dim_partition_dict_mapping['bias'] = {} + elif self.linear_projection_type == 'addmm': + dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]} + else: + raise ('Unsupported linear projection type') + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action mapping communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping["input"], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) - output_comm_spec = self.get_communication_spec( + + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_1) + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) - communication_action_mapping['input'] = input_comm_spec - communication_action_mapping['output'] = output_comm_spec + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=1) - if self.has_bias: - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping["bias"], - communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping['bias'] = bias_comm_spec + communication_action_mapping['other'] = other_comm_action + communication_action_mapping['output'] = output_comm_action + + # we only add allreduce comm action for linear bias, because + # allreduce comm action for addmm bias will be considered in post processing + if self.has_bias and self.linear_projection_type == 'linear': + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -316,20 +443,27 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): -1: [mesh_dim_1] }, } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication actions communication_action_mapping = {} - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_0) - input_comm_spec = self.get_communication_spec( + logical_process_axis=mesh_dim_0, + comm_type=CommType.AFTER) + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping["input"] = input_comm_spec - communication_action_mapping['output'] = output_comm_spec + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping["input"] = input_comm_action + communication_action_mapping['output'] = output_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -349,17 +483,19 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): "bias": {}, "output": {}, } - + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action communication_action_mapping = {} - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim) + logical_process_axis=mesh_dim, + comm_type=CommType.AFTER) - communication_action_mapping['output'] = output_comm_spec + communication_action_mapping['output'] = output_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -381,17 +517,20 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): -1: [mesh_dim] }, } - + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication actions communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim) + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=0) - communication_action_mapping['input'] = input_comm_spec + communication_action_mapping['input'] = input_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -410,22 +549,52 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): 0: [mesh_dim_0, mesh_dim_1] }, } + + # linear bias only has one dimension, but addmm bias has same dimensions + # as the output logically. + if self.linear_projection_type == 'linear': + dim_partition_dict_mapping['bias'] = {} + elif self.linear_projection_type == 'addmm': + dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]} + else: + raise ('Unsupported linear projection type') + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action communication_action_mapping = {} - other_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['other'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['other'] = other_comm_spec - - if self.has_bias: - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['bias'], + if self.is_param('other'): + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + other_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action + + # we only add allreduce comm action for linear bias, because + # allreduce comm action for addmm bias will be considered in post processing + if self.has_bias and self.linear_projection_type == 'linear': + if self.is_param('bias'): + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.HOOK) + else: + bias_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + key_for_kwarg='bias') + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -445,15 +614,19 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): "bias": {}, "output": {}, } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action communication_action_mapping = {} - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['output'] = output_comm_spec + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.AFTER) + communication_action_mapping['output'] = output_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -476,15 +649,43 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): -1: [mesh_dim_0, mesh_dim_1] }, } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['input'] = input_comm_spec + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['input'] = input_comm_action + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x RR' + + # get sharding spec + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "bias": {}, + "output": {}, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -500,10 +701,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): assert input_data.data.dim() > 0 and other_data.data.dim() == 2 assert other_data.logical_shape[0] == input_data.logical_shape[-1] - # check if bias has the same a valid dim - has_bias = "bias" in self.op_data - - if has_bias: + if self.has_bias: bias_data = self.op_data['bias'] assert bias_data.logical_shape[-1] == other_data.logical_shape[-1] @@ -516,8 +714,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): [b, i, k] x [b, k, j] -> [b, i, j] The bias term is considered to have a 2D logical shape. + + Note: This class will be used to generate strategies for torch.bmm + and torch.addbmm. However, the result of torch.addbmm is not correct, + some extra runtime apply actions are required to keep numerical correctness. """ + # TODO: torch.addbmm correctness issue need to be fixed. def __init__(self, *args, **kwargs): self.squeeze_batch_dim = False super().__init__(*args, **kwargs) @@ -537,7 +740,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): def validate(self) -> bool: input_op_data = self.op_data['input'] other_op_data = self.op_data['other'] - assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3 + assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3 if 'bias' in self.op_data: bias_op_data = self.op_data['bias'] @@ -566,16 +769,16 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): self._pop_batch_dim_sharding_for_output(dim_partition_dict) sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) - print(sharding_spec_mapping) - # get communication actions communication_action_mapping = {} if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=mesh_dim, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -602,11 +805,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -637,18 +842,24 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} - other_comm_spec = self.get_communication_spec( + other_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['other'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping['other'] = other_comm_spec + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=1) + communication_action_mapping['other'] = other_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=[mesh_dim_0, mesh_dim_1], + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action + # for addbmm case, other is the third argument instead of second. + communication_action_mapping['other'].arg_index += 1 return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -679,18 +890,23 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} - input_comm_spec = self.get_communication_spec( + input_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping['input'] = input_comm_spec + logical_process_axis=mesh_dim_1, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['input'] = input_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE) + communication_action_mapping['bias'] = bias_comm_action + # for addbmm case, other is the second argument instead of first. + communication_action_mapping['input'].arg_index += 1 return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -702,11 +918,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): dim_partition_dict = { "input": { 0: [mesh_dim_0], - -1: [mesh_dim_1] + 2: [mesh_dim_1] }, "other": { 0: [mesh_dim_0], - -2: [mesh_dim_1] + 1: [mesh_dim_1] }, "bias": {}, "output": { @@ -719,18 +935,21 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} - output_comm_spec = self.get_communication_spec( + output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping['output'] = output_comm_spec + logical_process_axis=mesh_dim_1, + comm_type=CommType.AFTER) + communication_action_mapping['output'] = output_comm_action if self.has_bias: - bias_comm_spec = self.get_communication_spec( + bias_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping['bias'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) - communication_action_mapping['bias'] = bias_comm_spec + logical_process_axis=mesh_dim_0, + comm_type=CommType.BEFORE, + arg_index=0) + communication_action_mapping['bias'] = bias_comm_action return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -771,6 +990,5 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # split two batch dim strategy_list.append(self.split_two_batch_dim(0, 1)) - strategy_list.append(self.split_two_batch_dim(1, 0)) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py index 457f51450..9df6d2fbf 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py @@ -3,9 +3,12 @@ import operator from functools import reduce from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) -from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, - enumerate_all_possible_2d_sharding) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) from .strategy_generator import StrategyGenerator @@ -79,6 +82,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost + @ignore_sharding_exception def _generate_strategy_with_dim_partition(self, dim_partition): dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py index de9dfba67..69d1642d4 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py @@ -1,6 +1,14 @@ -from typing import List +from typing import Dict, List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from torch.fx import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh from .strategy_generator import OutputStrategyGenerator @@ -12,6 +20,11 @@ class OutputGenerator(OutputStrategyGenerator): OutputGenerator is a generic class to generate strategies for Output Node. """ + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + predecessor_nodes: List[Node], output_option: str): + super().__init__(operation_data_mapping, device_mesh, predecessor_nodes) + self.output_option = output_option + def validate(self) -> bool: return super().validate() @@ -32,13 +45,27 @@ class OutputGenerator(OutputStrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - def collate_strategies(self) -> List[ShardingStrategy]: - dim_partition_dict_mapping = { - "output": {}, - } + def replica_strategy(self) -> List[ShardingStrategy]: + """ + Generate replica strategy for output node. + """ + dim_partition_dict_mapping = {} + dim_partition_dict_for_output = [] for index, _ in enumerate(self.predecessor_nodes): mapping_name = f"input_{index}" - dim_partition_dict_mapping[mapping_name] = {} + if isinstance(self.op_data[mapping_name].data, (tuple, list)): + dim_partition_dict_for_input = [{} for _ in range(len(self.op_data[mapping_name].data))] + else: + dim_partition_dict_for_input = {} + dim_partition_dict_mapping[mapping_name] = dim_partition_dict_for_input + dim_partition_dict_for_output.append(dim_partition_dict_for_input) + + if len(dim_partition_dict_for_output) == 1: + dim_partition_dict_for_output = dim_partition_dict_for_output[0] + else: + dim_partition_dict_for_output = tuple(dim_partition_dict_for_output) + + dim_partition_dict_mapping['output'] = dim_partition_dict_for_output communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -48,5 +75,47 @@ class OutputGenerator(OutputStrategyGenerator): strategy = self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + return strategy - return [strategy] + def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]: + """ + Generate distributed strategy for output node. + """ + # TODO: need to take care of the case when the first element of output only need to be sharded. + output_op_data = self.op_data['output'] + if isinstance(output_op_data.data, tuple): + length = len(output_op_data.data) + dim_partition_dict_mapping = { + "output": [{ + 0: mesh_list + }] * length, + } + else: + dim_partition_dict_mapping = { + "output": { + 0: mesh_list + }, + } + for index, _ in enumerate(self.predecessor_nodes): + mapping_name = f"input_{index}" + dim_partition_dict_mapping[mapping_name] = {0: mesh_list} + + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Distributed Output' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + return strategy + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + mesh_list = [0, 1] + if self.output_option == 'replicated': + strategy_list.append(self.replica_strategy()) + elif self.output_option == 'distributed': + strategy_list.append(self.distributed_strategy(mesh_list)) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py index 9023ab0fb..779a7ced9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py @@ -1,6 +1,12 @@ -from typing import List +from typing import Dict, List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh from .strategy_generator import StrategyGenerator @@ -12,6 +18,11 @@ class PlaceholderGenerator(StrategyGenerator): PlaceholderGenerator is a generic class to generate strategies for placeholder node. """ + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + placeholder_option: str): + super().__init__(operation_data_mapping, device_mesh) + self.placeholder_option = placeholder_option + def validate(self) -> bool: return super().validate() @@ -37,7 +48,10 @@ class PlaceholderGenerator(StrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - def collate_strategies(self) -> List[ShardingStrategy]: + def replica_placeholder(self) -> ShardingStrategy: + """ + Generate replica strategy for placeholder node. + """ dim_partition_dict_mapping = { "output": {}, } @@ -50,4 +64,37 @@ class PlaceholderGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - return [strategy] + return strategy + + def distributed_placeholder(self, mesh_list) -> ShardingStrategy: + """ + Generate distributed strategy for placeholder node. + """ + dim_partition_dict_mapping = { + "output": { + 0: mesh_list + }, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Distributed Placeholder' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + if self.placeholder_option == 'distributed': + mesh_list = [0, 1] + distributed_strategy = self.distributed_placeholder(mesh_list) + strategy_list.append(distributed_strategy) + else: + assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported' + replicated_strategy = self.replica_placeholder() + strategy_list.append(replicated_strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py index cbe0f0746..0b3506c27 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py @@ -96,7 +96,7 @@ class ReshapeGenerator(FollowingStrategyGenerator): arg_index=0) input_comm_action.comm_spec.gather_dim = total_mesh_dim_list - else: + elif len(total_mesh_dim_list) >= 2: source_spec = sharding_spec_mapping["input"] target_spec = ShardingSpec(device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, @@ -104,7 +104,11 @@ class ReshapeGenerator(FollowingStrategyGenerator): comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) - communication_action_mapping["input"] = input_comm_action + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action strategy = self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py new file mode 100644 index 000000000..a1ebadd04 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py @@ -0,0 +1,104 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +__all__ = ['SoftmaxGenerator'] + + +class SoftmaxGenerator(FollowingStrategyGenerator): + """ + SoftmaxGenerator is used to generate strategies for torch.nn.Softmax or F.softmax. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + ''' + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + input_size_product = reduce(operator.mul, sharded_input_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + forward_compute_cost = output_size_product * 2 + backward_compute_cost = input_size_product + total_compute_cost = forward_compute_cost + backward_compute_cost + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + softmax_dim = self.op_data['softmax_dim'].data + + if softmax_dim in dim_partition_dict_for_input: + recover_dims = dim_partition_dict_for_input.pop(softmax_dim) + + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index 8f57ee6a0..6d68521aa 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -17,6 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.utils import convert_dim_partition_dict class StrategyGenerator(ABC): @@ -67,21 +68,41 @@ class StrategyGenerator(ABC): Args: mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary. + + Notes: + The op_data.data is commonly type of torch.Tensor, torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data. + However, if the op_data.data is of other non-iterative types, such as float or int, we should return None. If the op_data.data is of some iterative types, such as + list or tuple, we should return a list of ShardingSpec objects follow the same rule as above mentioned. """ results = {} for op_data_name, dim_partition_dict in mapping.items(): if op_data_name in self.op_data: op_data = self.op_data[op_data_name] - if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor): - sharding_spec = [] - for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict): + + def _to_sharding_spec( + data: any, logical_shape: any, + dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]: + """ + This is a recursive function to convert the dim partition dict to a ShardingSpec object. + """ + if isinstance(data, torch.Tensor): + dim_size = len(logical_shape) + dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict) sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=output.shape, - dim_partition_dict=dim_partition_dict_element) - else: - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=op_data.logical_shape, - dim_partition_dict=dim_partition_dict) + entire_shape=logical_shape, + dim_partition_dict=dim_partition_dict) + return sharding_spec + elif isinstance(data, (list, tuple)): + sharding_spec = [] + for data_element, logical_shape_element, dim_partition_dict_element in zip( + data, logical_shape, dim_partition_dict): + sharding_spec.append( + _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)) + return sharding_spec + else: + return None + + sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict) results[op_data_name] = sharding_spec return results @@ -109,7 +130,8 @@ class StrategyGenerator(ABC): communication_pattern: CollectiveCommPattern, logical_process_axis: Union[int, List[int]], comm_type: CommType, - arg_index: int = -1) -> CommAction: + arg_index: int = -1, + key_for_kwarg: any = None) -> CommAction: """ A factory method to produce a CommAction object. """ @@ -117,7 +139,8 @@ class StrategyGenerator(ABC): communication_pattern=communication_pattern, logical_process_axis=logical_process_axis), comm_type=comm_type, - arg_index=arg_index) + arg_index=arg_index, + key_for_kwarg=key_for_kwarg) def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ @@ -180,13 +203,40 @@ class StrategyGenerator(ABC): Args: strategy (ShardingStrategy): the ShardingStrategy generated. key (str): the name of the operation data defined by the generator. - """ op_data = self.op_data[key] - sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device() - dtype = self.op_data[key].data.dtype - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - return reduce(operator.mul, sharded_shape) * size_per_elem_bytes + + def _compute_size_in_bytes_helper(sharding_spec, meta_data): + sharded_shape = sharding_spec.get_sharded_shape_per_device() + if len(sharded_shape) == 0: + num_elements = 1 + else: + num_elements = reduce(operator.mul, sharded_shape) + dtype = getattr(meta_data, 'dtype') + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + return num_elements * size_per_elem_bytes + + if isinstance(op_data.data, tuple): + assert isinstance(strategy.sharding_specs[op_data], list), \ + 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.' + total_bytes = 0 + for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]): + meta_data = op_data.data[index] + if isinstance(meta_data, torch.Tensor): + element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data) + else: + # if meta_data is not a tensor, we count the memroy as 0 + element_bytes = 0 + total_bytes += element_bytes + + else: + if isinstance(op_data.data, torch.Tensor): + total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data) + else: + # if op_data.data is not a tensor, we count the memroy as 0 + total_bytes = 0 + + return total_bytes def generate(self) -> List[ShardingStrategy]: """ @@ -244,6 +294,5 @@ class OutputStrategyGenerator(StrategyGenerator): def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node]): - self.op_data = operation_data_mapping - self.device_mesh = device_mesh + super().__init__(operation_data_mapping, device_mesh) self.predecessor_nodes = predecessor_nodes diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py new file mode 100644 index 000000000..a0fbc58d7 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py @@ -0,0 +1,113 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = ['SumGenerator'] + + +class SumGenerator(FollowingStrategyGenerator): + """ + SumGenerator deals with the sharding strategies of torch.sum op. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + input_size_product = reduce(operator.mul, sharded_input_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + compute_cost = TrainCycleItem(fwd=input_size_product, + bwd=output_size_product, + total=input_size_product + output_size_product) + + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + sum_dims, sum_mapping_dict = self.op_data['sum_info'].data + + # TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce + # among all the shard groups + recover_dims = [] + dim_partition_dict_for_output = {} + for dim in dim_partition_dict_for_input: + if dim in sum_dims: + recover_dims.append(dim) + elif dim in sum_mapping_dict: + dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim] + else: + raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims') + + for dim in recover_dims: + dim_partition_dict_for_input.pop(dim) + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py new file mode 100644 index 000000000..93cfc9eee --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py @@ -0,0 +1,67 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +from .strategy_generator import StrategyGenerator + +__all__ = ['TensorConstructorGenerator'] + + +class TensorConstructorGenerator(StrategyGenerator): + """ + TensorConstructorGenerator which deals with + the sharding strategies for tensor constructor operation, such as torch.arange. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + bwd_mem_cost = MemoryCost(activation=0, parameter=0) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + dim_partition_dict_mapping = { + "output": {}, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Replica Tensor Constructor' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py index 95c8e2efa..fa941f2cc 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py @@ -1,9 +1,12 @@ import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) -from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, - enumerate_all_possible_2d_sharding) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + ignore_sharding_exception, +) from .strategy_generator import StrategyGenerator @@ -50,6 +53,7 @@ class WhereGenerator(StrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost + @ignore_sharding_exception def _generate_strategy_with_dim_partition(self, dim_partition): dim_partition_dict_mapping = { "condition": dim_partition, diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py new file mode 100644 index 000000000..86f90694e --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py @@ -0,0 +1,81 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, SumGenerator + +__all__ = ['SumHandler'] + + +@operator_registry.register(torch.Tensor.sum) +@operator_registry.register(torch.sum) +class SumHandler(NodeHandler): + """ + A SumHandler which deals with the sharding strategies for torch.sum or torch.Tensor.sum. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SumGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + if len(self.node.args) > 1: + sum_dims = self.node.args[1] + else: + sum_dims = tuple(range(self.node.args[0]._meta_data.dim())) + + if isinstance(sum_dims, int): + sum_dims = (sum_dims,) + + # recover negative value to positive + num_dims = self.node.args[0]._meta_data.dim() + for i in range(len(sum_dims)): + if sum_dims[i] < 0: + sum_dims[i] += num_dims + + # mapping the input dims to output dims + # For examples: + # input: torch.rand(2, 3, 4, 5) + # output: torch.sum(input, (0, 2)) + # sum_mapping_dict = {1: 0, 3: 1} + # sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input + # sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input + sum_mapping_dict = {} + if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']: + for i in range(num_dims): + sum_mapping_dict.update({i: i}) + else: + output_index = 0 + for i in range(num_dims): + if i not in sum_dims: + sum_mapping_dict.update({i: output_index}) + output_index += 1 + assert output_index == self.node._meta_data.dim() + + sum_info = (sum_dims, sum_mapping_dict) + physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "sum_info": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py new file mode 100644 index 000000000..855a2e761 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py @@ -0,0 +1,32 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator +from .strategy.tensor_constructor_generator import TensorConstructorGenerator + +__all__ = ['TensorConstructorHandler'] + + +@operator_registry.register(torch.arange) +class TensorConstructorHandler(NodeHandler): + """ + A TensorConstructorHandler which deals with the sharding strategies for tensor constructor operations, such as torch.arange. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(TensorConstructorGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = {"output": physical_output_operand} + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index b99d4a071..0362de780 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -3,16 +3,23 @@ from typing import Dict, List import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, UnaryElementwiseGenerator __all__ = ['UnaryElementwiseHandler'] +@operator_registry.register(torch.Tensor.to) +@operator_registry.register(torch.Tensor.type) @operator_registry.register(torch.abs) @operator_registry.register(torch.nn.ReLU) -class UnaryElementwiseHandler(NodeHandler): +@operator_registry.register(torch.nn.Tanh) +@operator_registry.register(torch.tanh) +@operator_registry.register(torch.nn.modules.dropout.Dropout) +@operator_registry.register(torch.Tensor.contiguous) +@operator_registry.register(torch.nn.functional.dropout) +class UnaryElementwiseHandler(MetaInfoNodeHandler): """ A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. """ diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py index ebcd6c453..6de2aaafd 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py @@ -4,7 +4,7 @@ from typing import Dict, List import torch -from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector) +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector from ..utils import recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry @@ -57,32 +57,14 @@ class WhereHandler(NodeHandler): logical_operand.logical_shape = target_shape return logical_operand - def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector: - """ - Register different sharding strategies for the current node. - """ - strategy_generators = self.get_strategy_generator() - - for generator in strategy_generators: - strategies = generator.generate() - strategies_vector = map(self.post_process, strategies) - # compute the resharding costs based on the previous node - # strategies if specified - if compute_resharding_cost: - strategies = list(map(self.update_resharding_cost, strategies)) - self.strategies_vector.extend(strategies) - - self.strategies_vector = list(strategies_vector) - return self.strategies_vector - def post_process(self, strategy: ShardingStrategy): logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping() for key in logical_op_data_mapping.keys(): logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]] logical_shape = logical_op_data_mapping[key].logical_shape physical_shape = physical_op_data_mapping[key].logical_shape - physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec, logical_shape, - physical_shape) + physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec, logical_shape, physical_shape) strategy.sharding_specs.pop(logical_op_data_mapping[key]) strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}" diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py index 8dbb0014b..6af927272 100644 --- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -6,10 +6,17 @@ from typing import Any, Dict, List, Tuple, Union import torch from torch.fx.node import Node -from colossalai.tensor.shape_consistency import CommSpec +from colossalai.tensor.comm_spec import CommSpec from colossalai.tensor.sharding_spec import ShardingSpec -from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP +from .constants import ( + BCAST_FUNC_OP, + ELEMENTWISE_FUNC_OP, + ELEMENTWISE_METHOD_OP, + ELEMENTWISE_MODULE_OP, + RESHAPE_FUNC_OP, + RESHAPE_METHOD_OP, +) __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector'] @@ -43,8 +50,23 @@ class OperationData: def __post_init__(self): # if no logical shape is specified, use the data shape as the logical shape - if self.logical_shape is None and isinstance(self.data, torch.Tensor): - self.logical_shape = self.data.shape + if self.logical_shape is None: + + def _infer_logical_shape(data: any): + """ + This function is used to infer the logical shape of the data. + """ + if isinstance(data, torch.Tensor): + return data.shape + elif isinstance(data, torch.Size): + return None + elif isinstance(data, (tuple, list)): + data_type = type(data) + return data_type([_infer_logical_shape(d) for d in data]) + else: + return None + + self.logical_shape = _infer_logical_shape(self.data) def __repr__(self) -> str: return f'OperationData(name={self.name}, type={self.type})' @@ -79,9 +101,12 @@ class MemoryCost: Args: activation (int): the memory cost incurred by the activations in bytes. parameter (int): the memory cost incurred by the module parameter in bytes. + temp (int): the memory cost incurred by the temporary tensors in bytes. + buffer (int): the memory cost incurred by the module buffer in bytes. """ activation: int = 0 parameter: int = 0 + temp: int = 0 buffer: int = 0 @@ -115,6 +140,7 @@ class CommAction: comm_spec: CommSpec = None comm_type: CommType = None arg_index: int = -1 + key_for_kwarg: any = None @dataclass @@ -178,9 +204,15 @@ class ShardingStrategy: def _deepcopy_dict_vals(data: Dict): return {k: deepcopy(v) for k, v in data.items()} - sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None - communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None - resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None + sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None + # We need to deepcopy it when self.communication_actions is not None, instead of checking its __bool__ value. + # Consider the examples below: + # If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False. + # In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items. + communication_actions = _deepcopy_dict_vals( + self.communication_actions) if self.communication_actions is not None else None + # same reason as communication_actions + resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None compute_cost = deepcopy(self.compute_cost) communication_cost = deepcopy(self.communication_cost) memory_cost = deepcopy(self.memory_cost) @@ -209,8 +241,6 @@ class StrategiesVector(list): # fetch its input and output nodes # TODO: placeholder input nodes self.predecessor_nodes = list(node._input_nodes.keys()) - if self.node.op == 'output': - self.predecessor_nodes = list(node._input_nodes.keys())[:1] self.successor_nodes = list(node.users.keys()) def check_merge(self): @@ -230,10 +260,18 @@ class StrategiesVector(list): if self.node.target in ELEMENTWISE_FUNC_OP: merge_label = True # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case. - if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1: - merge_label = True - # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated. + # TODO: remove this after we support the fall back logic. + # if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1: + # merge_label = True + # we could merge reshape op, because their computation costs are negligible. if self.node.target in RESHAPE_FUNC_OP: merge_label = True + if self.node.op == 'call_method': + # we could merge reshape op, because their computation costs are negligible. + method = getattr(self.node.args[0]._meta_data.__class__, self.node.target) + if method in RESHAPE_METHOD_OP: + merge_label = True + if method in ELEMENTWISE_METHOD_OP: + merge_label = True return merge_label diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py index abddbf2b0..038e56547 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py +++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py @@ -1,13 +1,14 @@ -from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST import torch +from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST + class CostGraph: ''' A graph data structure to simplify the edge cost graph. It has two main functions: 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. - 2. To reduce the searching space, we merge computationally-trivial operators, such as + 2. To reduce the searching space, we merge computationally-trivial operators, such as element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will be given by the StrategiesVector depending on the type of target node and following nodes. @@ -62,16 +63,40 @@ class CostGraph: edge_cost[(j, i)] = resharding_cost_item.total self.edge_costs[node_pair] = edge_cost # add parents and children attribute to node - parent_nodes = [node for node in strategies_vector.predecessor_nodes] - children_nodes = [node for node in strategies_vector.successor_nodes] + # parent_nodes = [node for node in strategies_vector.predecessor_nodes] + # children_nodes = [node for node in strategies_vector.successor_nodes] + parent_nodes = [] + children_nodes = [] + + def _check_tensor_in_node(data): + """ + This method is used to check whether the data has a tensor inside or not. + """ + has_tensor_flag = False + if isinstance(data, torch.Tensor): + return True + elif isinstance(data, (tuple, list)): + for d in data: + has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d) + return has_tensor_flag + + for node in strategies_vector.predecessor_nodes: + if _check_tensor_in_node(node._meta_data): + parent_nodes.append(node) + for node in strategies_vector.successor_nodes: + if _check_tensor_in_node(node._meta_data): + children_nodes.append(node) + setattr(dst_node, 'parents', parent_nodes) setattr(dst_node, 'children', children_nodes) - # self._remove_invalid_node(dst_node, 'parents') - # self._remove_invalid_node(dst_node, 'children') if self.simplify and strategies_vector.check_merge(): for followed_node in strategies_vector.predecessor_nodes: - self.merge_pair.append((followed_node, dst_node)) + # we only merge node pairs which src node has a tensor element inside. + # This is necessay because the node without a tensor element inside will not + # be assigned any strategy. + if _check_tensor_in_node(followed_node._meta_data): + self.merge_pair.append((followed_node, dst_node)) def get_edge_cost(self, src_node, dst_node): return self.edge_costs[(src_node, dst_node)] @@ -79,14 +104,14 @@ class CostGraph: def merge_node(self, src_node, dst_node): ''' To merge dst_node into src_node, we need to do it in following steps: - + 1. For each strategy in dst_node, we need to pick an appropriate strategy - of src_node to merge, it is important because the logical resharding costs - between the parents node of src_node and merged node depend on the src_node + of src_node to merge, it is important because the logical resharding costs + between the parents node of src_node and merged node depend on the src_node strategies dispatching. For example, for the graph 0->1->2, after merging node 1 into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)] x represents the picking strategy of node 1 merged into node 2 strategy 0. - + 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs contains two parts, one is resharding costs between src_node strategy and dst_node strategy, another is the origin extra costs in src_node strategy. @@ -98,10 +123,9 @@ class CostGraph: src_node(Node): The node will be merged into dst_node. dst_node(Node): The node to integrate src_node. ''' - src_node_index = dst_node.parents.index(src_node) # build merge_map merge_map = {} - for src_index, strategy in enumerate(src_node.strategies_vector): + for src_index, _ in enumerate(src_node.strategies_vector): min_cost = INFINITY_COST lowest_cost_index = -1 for dst_index, dst_strategy in enumerate(dst_node.strategies_vector): @@ -139,7 +163,6 @@ class CostGraph: for i in range(self.node_lens[src_node]): for j in range(self.node_lens[child_node]): dst_strate_index = merge_map[i] - # dst_strategy = dst_node.strategies_vector[dst_strate_index] edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)] if new_node_pair not in self.edge_costs: self.edge_costs[new_node_pair] = edge_cost diff --git a/colossalai/auto_parallel/tensor_shard/solver/options.py b/colossalai/auto_parallel/tensor_shard/solver/options.py index 2d34f5c64..b52e55708 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/options.py +++ b/colossalai/auto_parallel/tensor_shard/solver/options.py @@ -1,11 +1,30 @@ from dataclasses import dataclass +from enum import Enum __all__ = ['SolverOptions'] +class SolverPerference(Enum): + """ + This enum class is to define the solver preference. + """ + STANDARD = 0 + DP = 1 + TP = 2 + + +class DataloaderOption(Enum): + """ + This enum class is to define the dataloader option. + """ + REPLICATED = 0 + DISTRIBUTED = 1 + + @dataclass class SolverOptions: """ SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. """ - fast: bool = False + solver_perference: SolverPerference = SolverPerference.STANDARD + dataloader_option: DataloaderOption = DataloaderOption.REPLICATED diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py index d6ce5e9fe..89d0da223 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/solver.py +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -32,7 +32,8 @@ class Solver: memory_budget: float = -1.0, solution_numbers: int = 1, forward_only: bool = False, - memory_increasing_coefficient: float = 1.3): + memory_increasing_coefficient: float = 1.3, + verbose=True): ''' Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. Argument: @@ -64,6 +65,7 @@ class Solver: self.last_s_val = None # The last objective value of the best ILP solution. self.last_objective = None + self.verbose = verbose def _recover_merged_node_strategy(self): ''' @@ -152,12 +154,16 @@ class Solver: if self.forward_only: origin_communication_cost = communication_cost_item.fwd compute_cost = compute_cost_item.fwd + # extract MemoryCost item from the memory TrainCycleItem memory_cost = memory_cost_item.fwd else: origin_communication_cost = communication_cost_item.total compute_cost = compute_cost_item.total + # extract MemoryCost item from the memory TrainCycleItem memory_cost = memory_cost_item.total + # extract the memory cost in float from MemoryCost item and sum them up + memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer compute_costs.append(compute_cost) # node in extra_node_costs means it has some extra communication # cost from node merging, so we need to add those extra communication @@ -177,7 +183,7 @@ class Solver: # omit initial value for nodes s_init_np = None - return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np + return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose def _call_solver_serialized_args(self, node_nums, @@ -192,7 +198,8 @@ class Solver: memory_costs, resharding_costs, alias_convert_costs, - s_init_np=None): + s_init_np=None, + verbose=True): """ Call the solver with serialized arguments. """ @@ -363,6 +370,8 @@ class Solver: for liveness_stage in liveness_set: mem = 0 for live_variable in liveness_stage.unique_live_vars: + if live_variable.node not in self.node_index_dict: + continue node_index = self.node_index_dict[live_variable.node] mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) prob += mem <= memory_budget @@ -407,8 +416,6 @@ class Solver: # if v[idx][row * C + col] > 0.5: # prob += s[i][row] + s[j][col] <= 1 - verbose = True - msg = verbose time_limit = 600 assert "COIN_CMD" in pulp.listSolvers( diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 57d5dfa79..042b9bb4b 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -1,3 +1,4 @@ +import builtins import math import operator from copy import deepcopy @@ -6,14 +7,17 @@ from typing import Dict, List import torch from torch.fx import Graph, Node -from colossalai.auto_parallel.tensor_shard.node_handler import (OuputHandler, PlacehodlerHandler, operator_registry) -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (ShardingStrategy, StrategiesVector) -from colossalai.auto_parallel.tensor_shard.utils import (generate_resharding_costs, generate_sharding_spec) +from colossalai.auto_parallel.tensor_shard.node_handler import ( + GetattrHandler, + OutputHandler, + PlaceholderHandler, + operator_registry, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector +from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec -from .options import SolverOptions +from .options import DataloaderOption, SolverOptions __all__ = ['StrategiesConstructor'] @@ -37,6 +41,7 @@ class StrategiesConstructor: self.leaf_strategies = [] self.strategy_map = {} self.solver_options = solver_options + self.no_strategy_nodes = [] def remove_duplicated_strategy(self, strategies_vector): ''' @@ -47,10 +52,6 @@ class StrategiesConstructor: name_checklist = [] remove_list = [] for strategy in strategies_vector: - if strategy is None: - print(strategies_vector.node.name) - print(strategies_vector) - assert False if strategy.name not in name_checklist: name_checklist.append(strategy.name) else: @@ -62,34 +63,46 @@ class StrategiesConstructor: """ This method is to build the strategy vector for each node in the computation graph. """ + + def _check_no_strategy_for_node(node): + if node.op in ('placeholder', 'get_attr', 'output'): + return False + + def _check_no_strategy_for_data(data): + label = True + if isinstance(data, torch.Tensor): + return False + elif isinstance(data, (tuple, list)): + for d in data: + label = label and _check_no_strategy_for_data(d) + return label + + return _check_no_strategy_for_data(node._meta_data) + for node in self.nodes: strategies_vector = StrategiesVector(node) + + if _check_no_strategy_for_node(node): + self.no_strategy_nodes.append(node) + pass + # placeholder node - if node.op == 'placeholder': - placeholder_handler = PlacehodlerHandler(node, self.device_mesh, strategies_vector) + elif node.op == 'placeholder': + if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: + placeholder_option = 'distributed' + else: + assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' + placeholder_option = 'replicated' + placeholder_handler = PlaceholderHandler(node, + self.device_mesh, + strategies_vector, + placeholder_option=placeholder_option) placeholder_handler.register_strategy() # get_attr node - if node.op == 'get_attr': - # Same as placeholder nodes, if solver_options.fast is True, we just let them in - # fully replicate status, then strategies of following node will be treated equally due - # to replicate status has no resharding cost to other status. At the same time, the searching - # space is smaller than enumerating all the possible sharding spec for the get_attr node. - # Otherwise, all the possible sharding spec for the get_attr node will be enumerated. - if self.solver_options.fast: - # create sharding strategy for get_attr - name = 'Replica Attribute' - dim_partition_dict = {} - output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) - # TODO: use meta_info_prop to profile memory cost - memory_cost = 0 - sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost) - strategies_vector.append(sharding_strategy_attribute) - - # # get_attr node - # elif node.op == 'get_attr': - # # TODO: implement getattr node handler - # pass + elif node.op == 'get_attr': + getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector) + getattr_handler.register_strategy() # call_module node elif node.op == 'call_module': @@ -98,28 +111,51 @@ class StrategiesConstructor: submod_type = type(submod) handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector) handler.register_strategy() + # attach metainfo_vector to node + if hasattr(handler, 'metainfo_vector'): + setattr(node, 'metainfo_vector', handler.metainfo_vector) # call_function node elif node.op == 'call_function': target = node.target handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector) handler.register_strategy() + # attach metainfo_vector to node + if hasattr(handler, 'metainfo_vector'): + setattr(node, 'metainfo_vector', handler.metainfo_vector) # call_method node elif node.op == 'call_method': method = getattr(node.args[0]._meta_data.__class__, node.target) handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector) handler.register_strategy() + # attach metainfo_vector to node + if hasattr(handler, 'metainfo_vector'): + setattr(node, 'metainfo_vector', handler.metainfo_vector) # output node elif node.op == 'output': - output_handler = OuputHandler(node, self.device_mesh, strategies_vector) + if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: + output_option = 'distributed' + else: + assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' + output_option = 'replicated' + output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) output_handler.register_strategy() - if len(strategies_vector) <= 0: - print(node.name) - assert len(strategies_vector) > 0 self.remove_duplicated_strategy(strategies_vector) setattr(node, 'strategies_vector', strategies_vector) self.leaf_strategies.append(strategies_vector) self.strategy_map[node] = strategies_vector + + # remove no strategy nodes + remove_list = [] + for strategies_vector in self.leaf_strategies: + if len(strategies_vector) == 0: + remove_list.append(strategies_vector.node) + + for node in remove_list: + if node.strategies_vector in self.leaf_strategies: + self.leaf_strategies.remove(node.strategies_vector) + if node in self.strategy_map: + self.strategy_map.pop(node) diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py index 380464bcd..b7fe5430b 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -1,6 +1,13 @@ -from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape +from .broadcast import ( + BroadcastType, + comm_actions_for_oprands, + get_broadcast_shape, + is_broadcastable, + recover_sharding_spec_for_broadcast_shape, +) from .factory import generate_resharding_costs, generate_sharding_spec -from .misc import check_sharding_spec_validity, ignore_sharding_exception +from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map +from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict from .sharding import ( enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, @@ -13,5 +20,6 @@ __all__ = [ 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', - 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' + 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map', + 'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict' ] diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py index a0edce9b9..28aa55132 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py +++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py @@ -2,10 +2,21 @@ from enum import Enum, auto from typing import List import torch +from torch.fx.node import Node +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + OperationData, + OperationDataType, +) +from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec from colossalai.tensor.sharding_spec import ShardingSpec -__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape'] +__all__ = [ + 'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape', + 'comm_actions_for_oprands' +] class BroadcastType(Enum): @@ -44,16 +55,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]: return dims[::-1] -def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, - physical_shape: torch.Size) -> ShardingSpec: - """ - This function computes the sharding spec for the physical shape of a broadcast tensor. - - Args: - logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor - logical_shape (torch.Size): logical shape is the broadcast shape of a tensor - physical_shape (torch.Size): the shape of the tensor before broadcasting - """ +def get_broadcast_dim_info(logical_shape, physical_shape): # get the number of dimensions logical_num_dims = len(logical_shape) physical_num_dims = len(physical_shape) @@ -80,6 +82,34 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe else: logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING + return logical_dim_broadcast_info + + +def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, + physical_shape: torch.Size) -> ShardingSpec: + """ + This function computes the sharding spec for the physical shape of a broadcast tensor. + + Args: + logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor + logical_shape (torch.Size): logical shape is the broadcast shape of a tensor + physical_shape (torch.Size): the shape of the tensor before broadcasting + """ + # if the two shapes are the same, no broadcast occurs + # we directly return the current sharding spec + + # recording the sharding dimensions removed during logical shape converting to physical one + removed_dims = [] + if list(logical_shape) == list(physical_shape): + return logical_sharding_spec, removed_dims + + # get the number of dimensions + logical_num_dims = len(logical_shape) + physical_num_dims = len(physical_shape) + + # get the broadcast info + logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape) + # generate the sharding spec for the physical shape physical_dim_partition = {} logical_dim_partition = logical_sharding_spec.dim_partition_dict @@ -88,7 +118,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe logical_broadcast_type = logical_dim_broadcast_info[shape_dim] if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE: - pass + removed_dims.extend(mesh_dim) else: # get the corresponding physical dim physical_dim = physical_num_dims - (logical_num_dims - shape_dim) @@ -98,4 +128,33 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe entire_shape=physical_shape, dim_partition_dict=physical_dim_partition) - return physical_sharding_spec + return physical_sharding_spec, removed_dims + + +def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData, + sharding_spec: ShardingSpec) -> CommAction: + """ + This method is used to generate communication actions for oprands which lose information + during convert logical shape to physical shape. + """ + if len(removed_dims) == 1: + # if list length is 1, extract element from list to avoid using flatten device mesh + removed_dims = removed_dims[0] + comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + sharding_spec=sharding_spec, + logical_process_axis=removed_dims) + if op_data.type == OperationDataType.PARAM: + comm_type = CommType.HOOK + else: + comm_type = CommType.BEFORE + arg_index = -1 + for index, arg in enumerate(node.args): + if op_data.name == str(arg): + arg_index = index + assert arg_index >= 0, f'op_data should be an argument of node.' + comm_action = CommAction( + comm_spec=comm_spec, + comm_type=comm_type, + arg_index=arg_index, + ) + return comm_action diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 967847390..9e402dab7 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -1,11 +1,12 @@ import functools +from typing import Any, Callable, Dict, List, Tuple, Type, Union import torch from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException -__all__ = ['ignore_sharding_exception'] +__all__ = ['ignore_sharding_exception', 'pytree_map'] def ignore_sharding_exception(func): @@ -70,3 +71,27 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens # make sure the entire shape matches the physical tensor shape assert sharding_spec.entire_shape == tensor.shape, \ f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}' + + +def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: + """process object recursively, like pytree + + Args: + obj (:class:`Any`): object to process + fn (:class:`Callable`): a function to process subobject in obj + process_types (:class: `type | tuple[type]`): types to determine the type to process + map_all (:class: `bool`): if map_all is True, then any type of element will use fn + + Returns: + :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` + """ + if isinstance(obj, dict): + return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj} + elif isinstance(obj, tuple): + return tuple(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, list): + return list(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, process_types): + return fn(obj) + else: + return fn(obj) if map_all else obj diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py new file mode 100644 index 000000000..a32a14bf7 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py @@ -0,0 +1,192 @@ +from enum import Enum +from typing import Dict, List, Tuple + +import torch + + +class PreviousStatus(Enum): + """ + This class shows the status of previous comparision. + """ + RESET = 0 + # ORIGIN means the dimension size of original tensor is larger in the previous comparision. + ORIGIN = 1 + # TGT means the dimension size of target tensor is larger in the previous comparision. + TGT = 2 + + +def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]: + """ + This method is used to detect the reshape mapping between original tensor and target tensor. + + Returns: + reshape_mapping_dict: The dictionary shows how a tuple of origin dims(keys) mapping to the related + target dims(values) during reshaping operation. + Examples: + import torch + origin_shape = torch.Size([4, 4, 4]) + tgt_shape = torch.Size([2, 8, 2, 2]) + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + print(reshape_mapping_dict) + Output: + {(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)} + """ + + # reverse the shape object + origin_shape = list(origin_shape) + tgt_shape = list(tgt_shape) + origin_shape.reverse() + tgt_shape.reverse() + + # initialize arguments + reshape_mapping_dict = {} + origin_len = len(origin_shape) + tgt_len = len(tgt_shape) + origin_index = 0 + tgt_index = 0 + original_dimension_size = origin_shape[origin_index] + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1] + origin_dims = [origin_len - origin_index - 1] + previous_label = PreviousStatus.RESET + + while origin_index != len(origin_shape) or tgt_index != len(tgt_shape): + if original_dimension_size == tgt_dimension_size: + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + # if the origin_dims has no element, it means the original tensor has been fully matched. + # Therefore, we do not have to increase the origin_index for that case. + if len(origin_dims) > 0: + origin_index += 1 + # if the tgt_dims has no element, it means the original tensor has been fully matched. + # Therefore, we do not have to increase the tgt_index for that case. + if len(tgt_dims) > 0: + tgt_index += 1 + # the last step of loop should always end with condition + # so we need to manually skip the preparation for next step + # in the last step. + if origin_index == len(origin_shape) and tgt_index == len(tgt_shape): + continue + + # If origin_index equals to origin_len, we just need to set the original_dimension_size + # to 1 to match the remaining '1's in the target tensor shape. + if origin_index == len(origin_shape): + original_dimension_size = 1 + origin_dims = [] + else: + original_dimension_size = origin_shape[origin_index] + origin_dims = [origin_len - origin_index - 1] + + # If tgt_index equals to tgt_len, we just need to set the tgt_dimension_size + # to 1 to match the remaining '1's in the original tensor shape. + if tgt_index == len(tgt_shape): + tgt_dimension_size = 1 + tgt_dims = [] + else: + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1] + + previous_label = PreviousStatus.RESET + + elif original_dimension_size > tgt_dimension_size: + tgt_index += 1 + + if previous_label == PreviousStatus.TGT: + # if the target dimension size is larger in the previous comparision, which means + # the origin dimension size has already accumulated larger than target dimension size, so + # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + original_dimension_size = original_dimension_size // tgt_dimension_size + origin_dims = [origin_len - origin_index - 1] + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index] + # reset the previous_label after offloading the origin dims and tgt dims + previous_label = PreviousStatus.RESET + else: + # accumulate the tgt_dimension_size until tgt_dimension_size larger than original_dimension_size + tgt_dimension_size *= tgt_shape[tgt_index] + tgt_dims.append(tgt_len - tgt_index - 1) + previous_label = PreviousStatus.ORIGIN + + else: + origin_index += 1 + + if previous_label == PreviousStatus.ORIGIN: + # if the origin element is larger in the previous comparision, which means + # the target element has already accumulated larger than origin element, so + # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + tgt_dimension_size = tgt_dimension_size // original_dimension_size + tgt_dims = [tgt_len - tgt_index - 1] + original_dimension_size = origin_shape[origin_index] + origin_dims = [origin_len - origin_index - 1, origin_len - origin_index] + # reset the previous_label after offloading the origin dims and tgt dims + previous_label = PreviousStatus.RESET + else: + # accumulate the original_dimension_size until original_dimension_size larger than tgt_dimension_size + original_dimension_size *= origin_shape[origin_index] + origin_dims.append(origin_len - origin_index - 1) + previous_label = PreviousStatus.TGT + + return reshape_mapping_dict + + +def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], + reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool: + """ + This method is used to check whether the reshape operation could implement without converting + the input to fully replicated status. + + Rule: + For a sharded dimension of input tensor, if it is not the minimum element of the input tuple, + the function will return false. + To illustrate this issue, there are two cases to analyse: + 1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal + operation without distributed tensor. + 2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape + consistency process, torch.cat will be implemented on the sharded dim, and everything after the sharded + dim get recovered. + + Examples: + # the second dimension of the input has been sharded. + input_dim_partition_dict = {1: [1]} + origin_shape = torch.Size([8, 4, 2]) + tgt_shape = torch.Size([2, 4, 8]) + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + # {(2, 1): (2,), (0,): (1, 0)} + # the sharded dim of input is 1, which is the minimum element of the tuple (2, 1), + # so we do not have to convert the input to fully replicated status. + print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict)) + + Output: + True + """ + sharded_dims = list(input_dim_partition_dict.keys()) + for input_dims in reshape_mapping_dict.keys(): + # if input_dims has no element, we could just skip this iteration. + if len(input_dims) == 0: + continue + min_element = min(input_dims) + for dim in input_dims: + if dim in sharded_dims and dim is not min_element: + return False + return True + + +def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]], + reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]: + """ + This method is used to infer the output dim partition dict for a reshape operation, + given the input dim partition dict and reshape mapping dict. + """ + assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \ + 'we only infer output dim partition dict for the reshape operation could keep sharding spec.' + sharded_dims = list(input_dim_partition_dict.keys()) + output_dim_partition_dict = {} + for input_dims, output_dims in reshape_mapping_dict.items(): + for dim in input_dims: + if dim in sharded_dims: + output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim] + # we could break because input dims cannot contain two sharded dims, otherwise + # the keep sharding status check will fail. + break + return output_dim_partition_dict diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py index c020d33b6..618ff8c61 100644 --- a/colossalai/cli/benchmark/__init__.py +++ b/colossalai/cli/benchmark/__init__.py @@ -1,9 +1,10 @@ import click -from .utils import * -from .benchmark import run_benchmark from colossalai.context import Config +from .benchmark import run_benchmark +from .utils import * + __all__ = ['benchmark'] diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py index 43632b150..f40f8f2f9 100644 --- a/colossalai/cli/benchmark/benchmark.py +++ b/colossalai/cli/benchmark/benchmark.py @@ -1,16 +1,17 @@ -import colossalai +from functools import partial +from typing import Dict, List + import click import torch.multiprocessing as mp -from functools import partial -from typing import List, Dict - +import colossalai +from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model from colossalai.context import Config from colossalai.context.random import reset_seeds from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.utils import free_port, MultiTimer -from colossalai.cli.benchmark.utils import find_all_configs, profile_model, get_batch_data +from colossalai.utils import MultiTimer, free_port + from .models import MLP @@ -53,7 +54,7 @@ def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_ port_list (List[int]): a list of free ports for initializing distributed networks config_list (List[Dict]): a list of configuration hyperparams (Config): the hyperparameters given by the user - + """ # disable logging for clean output diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py index eab0bc1ed..22c169577 100644 --- a/colossalai/cli/check/check_installation.py +++ b/colossalai/cli/check/check_installation.py @@ -1,33 +1,194 @@ -import click import subprocess + +import click import torch from torch.utils.cpp_extension import CUDA_HOME +import colossalai + + +def to_click_output(val): + # installation check output to understandable symbols for readability + VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'} + + if val in VAL_TO_SYMBOL: + return VAL_TO_SYMBOL[val] + else: + return val + def check_installation(): - cuda_ext_installed = _check_cuda_extension_installed() - cuda_version, torch_version, torch_cuda_version, cuda_torch_compatibility = _check_cuda_torch() + """ + This function will check the installation of colossalai, specifically, the version compatibility of + colossalai, pytorch and cuda. - click.echo(f"CUDA Version: {cuda_version}") - click.echo(f"PyTorch Version: {torch_version}") - click.echo(f"CUDA Version in PyTorch Build: {torch_cuda_version}") - click.echo(f"PyTorch CUDA Version Match: {cuda_torch_compatibility}") - click.echo(f"CUDA Extension: {cuda_ext_installed}") + Example: + ```text + ``` + + Returns: A table of installation information. + """ + found_aot_cuda_ext = _check_aot_built_cuda_extension_installed() + cuda_version = _check_cuda_version() + torch_version, torch_cuda_version = _check_torch_version() + colossalai_verison, torch_version_required, cuda_version_required = _parse_colossalai_version() + + # if cuda_version is None, that means either + # CUDA_HOME is not found, thus cannot compare the version compatibility + if not cuda_version: + sys_torch_cuda_compatibility = None + else: + sys_torch_cuda_compatibility = _is_compatible([cuda_version, torch_cuda_version]) + + # if cuda_version or cuda_version_required is None, that means either + # CUDA_HOME is not found or AOT compilation is not enabled + # thus, there is no need to compare the version compatibility at all + if not cuda_version or not cuda_version_required: + sys_colossalai_cuda_compatibility = None + else: + sys_colossalai_cuda_compatibility = _is_compatible([cuda_version, cuda_version_required]) + + # if torch_version_required is None, that means AOT compilation is not enabled + # thus there is no need to compare the versions + if torch_version_required is None: + torch_compatibility = None + else: + torch_compatibility = _is_compatible([torch_version, torch_version_required]) + + click.echo(f'#### Installation Report ####') + click.echo(f'\n------------ Environment ------------') + click.echo(f"Colossal-AI version: {to_click_output(colossalai_verison)}") + click.echo(f"PyTorch version: {to_click_output(torch_version)}") + click.echo(f"CUDA version: {to_click_output(cuda_version)}") + click.echo(f"CUDA version required by PyTorch: {to_click_output(torch_cuda_version)}") + click.echo("") + click.echo(f"Note:") + click.echo(f"1. The table above checks the versions of the libraries/tools in the current environment") + click.echo(f"2. If the CUDA version is N/A, you can set the CUDA_HOME environment variable to locate it") + + click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------') + click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}") + click.echo(f"PyTorch version used for AOT compilation: {to_click_output(torch_version_required)}") + click.echo(f"CUDA version used for AOT compilation: {to_click_output(cuda_version_required)}") + click.echo("") + click.echo(f"Note:") + click.echo( + f"1. AOT (ahead-of-time) compilation of the CUDA kernels occurs during installation when the environment varialbe CUDA_EXT=1 is set" + ) + click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime") + + click.echo(f"\n------------ Compatibility ------------") + click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}') + click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}") + click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}") + click.echo(f"") + click.echo(f"Note:") + click.echo(f"1. The table above checks the version compatibility of the libraries/tools in the current environment") + click.echo( + f" - PyTorch version mistach: whether the PyTorch version in the current environment is compatible with the PyTorch version used for AOT compilation" + ) + click.echo( + f" - System and PyTorch CUDA version match: whether the CUDA version in the current environment is compatible with the CUDA version required by PyTorch" + ) + click.echo( + f" - System and Colossal-AI CUDA version match: whether the CUDA version in the current environment is compatible with the CUDA version used for AOT compilation" + ) -def _check_cuda_extension_installed(): +def _is_compatible(versions): + """ + Compare the list of versions and return whether they are compatible. + """ + if None in versions: + return False + + # split version into [major, minor, patch] + versions = [version.split('.') for version in versions] + + for version in versions: + if len(version) == 2: + # x means unknown + version.append('x') + + for idx, version_values in enumerate(zip(*versions)): + equal = len(set(version_values)) == 1 + + if idx in [0, 1] and not equal: + return False + elif idx == 1: + return True + else: + continue + + +def _parse_colossalai_version(): + """ + Get the Colossal-AI version information. + + Returns: + colossalai_version: Colossal-AI version. + torch_version_for_aot_build: PyTorch version used for AOT compilation of CUDA kernels. + cuda_version_for_aot_build: CUDA version used for AOT compilation of CUDA kernels. + """ + # colossalai version can be in two formats + # 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions) + # 2. X.X.X (when colossalai is not installed with CUDA extensions) + # where X represents an integer. + colossalai_verison = colossalai.__version__.split('+')[0] + try: - import colossal_C - is_cuda_extension_installed = u'\u2713' + torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0] + cuda_version_for_aot_build = colossalai.__version__.split('cu')[1] + except: + torch_version_for_aot_build = None + cuda_version_for_aot_build = None + return colossalai_verison, torch_version_for_aot_build, cuda_version_for_aot_build + + +def _check_aot_built_cuda_extension_installed(): + """ + According to `op_builder/README.md`, the CUDA extension can be built with either + AOT (ahead-of-time) or JIT (just-in-time) compilation. + AOT compilation will build CUDA extensions to `colossalai._C` during installation. + JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime. + """ + try: + import colossalai._C.fused_optim + found_aot_cuda_ext = True except ImportError: - is_cuda_extension_installed = 'x' - return is_cuda_extension_installed + found_aot_cuda_ext = False + return found_aot_cuda_ext -def _check_cuda_torch(): +def _check_torch_version(): + """ + Get the PyTorch version information. + + Returns: + torch_version: PyTorch version. + torch_cuda_version: CUDA version required by PyTorch. + """ + # get torch version + torch_version = torch.__version__.split('+')[0] + + # get cuda version in pytorch build + torch_cuda_major = torch.version.cuda.split(".")[0] + torch_cuda_minor = torch.version.cuda.split(".")[1] + torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' + + return torch_version, torch_cuda_version + + +def _check_cuda_version(): + """ + Get the CUDA version information. + + Returns: + cuda_version: CUDA version found on the system. + """ # get cuda version if CUDA_HOME is None: - cuda_version = 'N/A (CUDA_HOME is not set)' + cuda_version = CUDA_HOME else: raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) output = raw_output.split() @@ -36,22 +197,4 @@ def _check_cuda_torch(): bare_metal_major = release[0] bare_metal_minor = release[1][0] cuda_version = f'{bare_metal_major}.{bare_metal_minor}' - - # get torch version - torch_version = torch.__version__ - - # get cuda version in pytorch build - torch_cuda_major = torch.version.cuda.split(".")[0] - torch_cuda_minor = torch.version.cuda.split(".")[1] - torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' - - # check version compatiblity - cuda_torch_compatibility = 'x' - if CUDA_HOME: - if torch_cuda_major == bare_metal_major: - if torch_cuda_minor == bare_metal_minor: - cuda_torch_compatibility = u'\u2713' - else: - cuda_torch_compatibility = u'\u2713 (minor version mismatch)' - - return cuda_version, torch_version, torch_cuda_version, cuda_torch_compatibility + return cuda_version diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py index 2c9e9927c..64fb5b8b5 100644 --- a/colossalai/communication/collective.py +++ b/colossalai/communication/collective.py @@ -3,12 +3,17 @@ import torch import torch.distributed as dist -from torch.distributed import ReduceOp from torch import Tensor +from torch.distributed import ReduceOp from colossalai.context import ParallelMode from colossalai.core import global_context as gpc +_all_gather_func = dist._all_gather_base \ + if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor +_reduce_scatter_func = dist._reduce_scatter_base \ + if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor + def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: r"""Gathers all tensors from the parallel group and concatenates them in a @@ -33,17 +38,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: out = tensor work = None else: - shape = list(tensor.shape) - shape[0], shape[dim] = shape[dim], shape[0] - shape[0] *= depth - out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device) - temp = list(torch.chunk(out, depth, dim=0)) + tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous() + out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:] + tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device) group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) - work = dist.all_gather(tensor_list=temp, - tensor=tensor.transpose(0, dim).contiguous(), - group=group, - async_op=async_op) - out = torch.transpose(out, 0, dim) + work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op) + out = tensor_out if dim == 0 else tensor_out.transpose(0, dim) if async_op: return out, work else: @@ -81,10 +81,12 @@ def reduce_scatter(tensor: Tensor, out = tensor work = None else: - temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim))) - out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device) + tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous() + out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:] + tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device) group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) - work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op) + work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op) + out = tensor_out if dim == 0 else tensor_out.transpose(0, dim) if async_op: return out, work else: @@ -193,7 +195,8 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None: - r"""Modified from `torch.distributed.scatter_object_list ` to fix issues + r"""Modified from `torch.distributed.scatter_object_list + ` to fix issues """ if dist.distributed_c10d._rank_not_in_group(group): return diff --git a/colossalai/communication/p2p_v2.py b/colossalai/communication/p2p_v2.py index 0b575e7db..4223f78d5 100644 --- a/colossalai/communication/p2p_v2.py +++ b/colossalai/communication/p2p_v2.py @@ -1,14 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import List, Tuple, Union, Any -import pickle import io +import pickle +from typing import Any, List, Tuple, Union import torch import torch.distributed as dist -from torch.distributed import distributed_c10d as c10d from torch.distributed import ProcessGroupNCCL +from torch.distributed import distributed_c10d as c10d from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc @@ -23,7 +23,7 @@ def init_process_group(): Args: None - + Returns: None """ @@ -40,7 +40,7 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou second_rank (int): second rank in the pair Returns: - :class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks + :class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks """ if len(_pg_manager) == 0: init_process_group() @@ -51,8 +51,8 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object: - """transform tensor to object with unpickle. - Info of the device in bytes stream will be modified into current device before unpickling + """transform tensor to object with unpickle. + Info of the device in bytes stream will be modified into current device before unpickling Args: tensor (:class:`torch.tensor`): tensor to be unpickled @@ -78,9 +78,9 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None): """This is a modified version of the broadcast_object_list in torch.distribution The only difference is that object will be move to correct device after unpickled. - If local_rank = src, then object list will be sent to rank src. Otherwise, object list will + If local_rank = src, then object list will be sent to rank src. Otherwise, object list will be updated with data sent from rank src. - + Args: object_list (List[Any]): list of object to broadcast src (int): source rank to broadcast @@ -182,7 +182,7 @@ def _recv_object(src: int) -> Any: Args: src (int): source rank of data. local rank will receive data from src rank. - + Returns: Any: Object received from src. """ diff --git a/colossalai/constants.py b/colossalai/constants.py index c8aaafdfa..6cf9085f9 100644 --- a/colossalai/constants.py +++ b/colossalai/constants.py @@ -23,6 +23,8 @@ INITIALIZER_MAPPING = { INPUT_GROUP_3D = 'input_group_3d' WEIGHT_GROUP_3D = 'weight_group_3d' OUTPUT_GROUP_3D = 'output_group_3d' +INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d' +OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d' # Attributes of tensor parallel parameters IS_TENSOR_PARALLEL = 'is_tensor_parallel' diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index afa306065..b7338b53d 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -370,12 +370,12 @@ class ParallelContext(metaclass=SingletonMeta): port (str): the master port for distributed training """ # initialize the default process group - init_method = f'tcp://{host}:{port}' + init_method = f'tcp://[{host}]:{port}' dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # None will give the default global process group for pytorch dist operations ranks = list(range(world_size)) - cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None + cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() == 'gloo' else None self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) self.add_global_rank(ParallelMode.GLOBAL, rank) diff --git a/colossalai/context/parallel_mode.py b/colossalai/context/parallel_mode.py index dc50dca05..1cf6fa53d 100644 --- a/colossalai/context/parallel_mode.py +++ b/colossalai/context/parallel_mode.py @@ -39,6 +39,8 @@ class ParallelMode(Enum): PARALLEL_3D_INPUT = '3d_input' PARALLEL_3D_WEIGHT = '3d_weight' PARALLEL_3D_OUTPUT = '3d_output' + PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight" + PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight" # 2.5D parallel PARALLEL_2P5D_ROW = '2p5d_row' diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py index 0cda7a52d..b752b8f45 100644 --- a/colossalai/context/process_group_initializer/initializer_3d.py +++ b/colossalai/context/process_group_initializer/initializer_3d.py @@ -176,6 +176,112 @@ class Initializer_3D_Output(ProcessGroupInitializer): return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode +class Initializer_3D_InputxWeight(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_INPUT_X_WEIGHT + env.input_x_weight_group_3d = mode + + for h in range(self.num_group): + for k in range(self.depth): + ranks = [ + h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth) + for i in range(self.depth) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_OutputxWeight(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_OUTPUT_X_WEIGHT + env.output_x_weight_group_3d = mode + + for h in range(self.num_group): + for j in range(self.depth): + ranks = [ + h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth) + for i in range(self.depth) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + @DIST_GROUP_INITIALIZER.register_module class Initializer_3D(ProcessGroupInitializer): """Serve as the single entry point to 3D parallel initialization. @@ -200,6 +306,8 @@ class Initializer_3D(ProcessGroupInitializer): self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args) self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args) self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args) + self.input_x_weight_initializer = Initializer_3D_InputxWeight(self.num_group, self.depth, *args) + self.output_x_weight_initializer = Initializer_3D_OutputxWeight(self.num_group, self.depth, *args) def init_dist_group(self): """Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu. @@ -211,6 +319,8 @@ class Initializer_3D(ProcessGroupInitializer): parallel_setting = [ self.input_initializer.init_dist_group(), self.weight_initializer.init_dist_group(), - self.output_initializer.init_dist_group() + self.output_initializer.init_dist_group(), + self.input_x_weight_initializer.init_dist_group(), + self.output_x_weight_initializer.init_dist_group() ] return parallel_setting diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py index e69de29bb..689189998 100644 --- a/colossalai/device/__init__.py +++ b/colossalai/device/__init__.py @@ -0,0 +1,4 @@ +from .alpha_beta_profiler import AlphaBetaProfiler +from .calc_pipeline_strategy import alpa_dp + +__all__ = ['AlphaBetaProfiler', 'alpa_dp'] diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py new file mode 100644 index 000000000..9c66cb85d --- /dev/null +++ b/colossalai/device/alpha_beta_profiler.py @@ -0,0 +1,386 @@ +import math +import time +from typing import Dict, List, Tuple + +import torch +import torch.distributed as dist + +from colossalai.logging import get_dist_logger + +GB = int((1 << 30)) +BYTE = 4 +FRAMEWORK_LATENCY = 0 + + +class AlphaBetaProfiler: + ''' + Profile alpha and beta value for a given device list. + + Usage: + # Note: the environment of execution is supposed to be + # multi-process with multi-gpu in mpi style. + >>> physical_devices = [0, 1, 4, 5] + >>> ab_profiler = AlphaBetaProfiler(physical_devices) + >>> ab_dict = profiler.alpha_beta_dict + >>> print(ab_dict) + {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11), + (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12), + (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11), + (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)} + ''' + + def __init__(self, + physical_devices: List[int], + alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None, + ctype: str = 'a', + warmup: int = 5, + repeat: int = 25, + latency_iters: int = 5, + homogeneous_tolerance: float = 0.1): + ''' + Args: + physical_devices: A list of device id, each element inside it is the global rank of that device. + alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs. + ctype: 'a' for all-reduce, 'b' for broadcast. + warmup: Number of warmup iterations. + repeat: Number of iterations to measure. + latency_iters: Number of iterations to measure latency. + ''' + self.physical_devices = physical_devices + self.ctype = ctype + self.world_size = len(physical_devices) + self.warmup = warmup + self.repeat = repeat + self.latency_iters = latency_iters + self.homogeneous_tolerance = homogeneous_tolerance + self.process_group_dict = None + self._init_profiling() + if alpha_beta_dict is None: + self.alpha_beta_dict = self.profile_ab() + else: + self.alpha_beta_dict = alpha_beta_dict + + def _init_profiling(self): + # Create process group list based on its global rank + process_group_list = [] + for f_index in range(self.world_size - 1): + for b_index in range(f_index + 1, self.world_size): + process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index])) + + # Create process group dict which maps process group to its handler + process_group_dict = {} + for process_group in process_group_list: + pg_handler = dist.new_group(process_group) + process_group_dict[process_group] = pg_handler + + self.process_group_dict = process_group_dict + + def _profile(self, process_group, pg_handler, nbytes): + logger = get_dist_logger() + rank = dist.get_rank() + src_device_num = process_group[0] + world_size = len(process_group) + + device = torch.cuda.current_device() + buf = torch.randn(nbytes // 4).to(device) + + torch.cuda.synchronize() + # warmup + for _ in range(self.warmup): + if self.ctype == "a": + dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler) + elif self.ctype == "b": + dist.broadcast(buf, src=src_device_num, group=pg_handler) + torch.cuda.synchronize() + + dist.barrier(group=pg_handler) + begin = time.perf_counter() + for _ in range(self.repeat): + if self.ctype == "a": + dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler) + elif self.ctype == "b": + dist.broadcast(buf, src=src_device_num, group=pg_handler) + torch.cuda.synchronize() + end = time.perf_counter() + dist.barrier(group=pg_handler) + + if rank == src_device_num: + avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY + alg_band = nbytes / avg_time_s + if self.ctype == "a": + # convert the bandwidth of all-reduce algorithm to the bandwidth of the hardware. + bus_band = 2 * (world_size - 1) / world_size * alg_band + bus_band = alg_band + elif self.ctype == "b": + bus_band = alg_band + + logger.info( + f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s" + ) + return (avg_time_s, alg_band) + else: + # Just a placeholder + return (None, None) + + def profile_latency(self, process_group, pg_handler): + ''' + This function is used to profile the latency of the given process group with a series of bytes. + + Args: + process_group: A tuple of global rank of the process group. + pg_handler: The handler of the process group. + + Returns: + latency: None if the latency is not measured, otherwise the median of the latency_list. + ''' + latency_list = [] + for i in range(self.latency_iters): + nbytes = int(BYTE << i) + (t, _) = self._profile(process_group, pg_handler, nbytes) + latency_list.append(t) + + if latency_list[0] is None: + latency = None + else: + median_index = math.floor(self.latency_iters / 2) + latency = latency_list[median_index] + + return latency + + def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)): + ''' + This function is used to profile the bandwidth of the given process group. + + Args: + process_group: A tuple of global rank of the process group. + pg_handler: The handler of the process group. + ''' + (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes) + return bandwidth + + def profile_ab(self): + ''' + This method is used to profiling the alpha and beta value for a given device list. + + Returns: + alpha_beta_dict: A dict which maps process group to its alpha and beta value. + ''' + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {} + rank = dist.get_rank() + global_pg_handler = dist.new_group(self.physical_devices) + + def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup): + assert rank in process_group + device = torch.cuda.current_device() + rank_max_nbytes = torch.cuda.mem_get_info(device)[0] + rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device) + dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler) + max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB)))) + return max_nbytes + + for process_group, pg_handler in self.process_group_dict.items(): + if rank not in process_group: + max_nbytes = None + alpha = None + bandwidth = None + else: + max_nbytes = get_max_nbytes(process_group, pg_handler) + alpha = self.profile_latency(process_group, pg_handler) + bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes) + + if bandwidth is None: + beta = None + else: + beta = 1 / bandwidth + + broadcast_list = [alpha, beta] + dist.broadcast_object_list(broadcast_list, src=process_group[0]) + alpha_beta_dict[process_group] = tuple(broadcast_list) + + # add symmetry pair to the apha_beta_dict + symmetry_ab_dict = {} + for process_group, alpha_beta_pair in alpha_beta_dict.items(): + symmetry_process_group = (process_group[1], process_group[0]) + symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair + + alpha_beta_dict.update(symmetry_ab_dict) + + return alpha_beta_dict + + def search_best_logical_mesh(self): + ''' + This method is used to search the best logical mesh for the given device list. + + The best logical mesh is searched in following steps: + 1. detect homogeneous device groups, we assume that the devices in the alpha_beta_dict + are homogeneous if the beta value is close enough. + 2. Find the best homogeneous device group contains all the physical devices. The best homogeneous + device group means the lowest beta value in the groups which contains all the physical devices. + And the reason we require the group contains all the physical devices is that the devices not in + the group will decrease the bandwidth of the group. + 3. If the best homogeneous device group is found, we will construct the largest ring for each device + based on the best homogeneous device group, and the best logical mesh will be the union of all the + rings. Otherwise, the best logical mesh will be the balanced logical mesh, such as shape (2, 2) for + 4 devices. + + Returns: + best_logical_mesh: The best logical mesh for the given device list. + + Usage: + >>> physical_devices = [0, 1, 2, 3] + >>> ab_profiler = AlphaBetaProfiler(physical_devices) + >>> best_logical_mesh = profiler.search_best_logical_mesh() + >>> print(best_logical_mesh) + [[0, 1], [2, 3]] + ''' + + def _power_of_two(integer): + return integer & (integer - 1) == 0 + + def _detect_homogeneous_device(alpha_beta_dict): + ''' + This function is used to detect whether the devices in the alpha_beta_dict are homogeneous. + + Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value + of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)] + * base_beta. + ''' + homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {} + for process_group, (_, beta) in alpha_beta_dict.items(): + if homogeneous_device_dict is None: + homogeneous_device_dict[beta] = [] + homogeneous_device_dict[beta].append(process_group) + + match_beta = None + for beta_value in homogeneous_device_dict.keys(): + if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * ( + 1 - self.homogeneous_tolerance): + match_beta = beta_value + break + + if match_beta is not None: + homogeneous_device_dict[match_beta].append(process_group) + else: + homogeneous_device_dict[beta] = [] + homogeneous_device_dict[beta].append(process_group) + + return homogeneous_device_dict + + def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]): + ''' + This function is used to check whether the homogeneous_group contains all physical devices. + ''' + flatten_mesh = [] + for process_group in homogeneous_group: + flatten_mesh.extend(process_group) + non_duplicated_flatten_mesh = set(flatten_mesh) + return len(non_duplicated_flatten_mesh) == len(self.physical_devices) + + def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): + ''' + This function is used to construct the largest ring in the homogeneous_group for each rank. + ''' + # Construct the ring + ring = [] + ranks_in_ring = [] + for rank in self.physical_devices: + if rank in ranks_in_ring: + continue + stable_status = False + ring_for_rank = [] + ring_for_rank.append(rank) + check_rank_list = [rank] + rank_to_check_list = [] + + while not stable_status: + stable_status = True + check_rank_list.extend(rank_to_check_list) + rank_to_check_list = [] + for i in range(len(check_rank_list)): + check_rank = check_rank_list.pop() + for process_group in homogeneous_group: + if check_rank in process_group: + rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1] + if rank_to_append not in ring_for_rank: + stable_status = False + rank_to_check_list.append(rank_to_append) + ring_for_rank.append(rank_to_append) + + ring.append(ring_for_rank) + ranks_in_ring.extend(ring_for_rank) + + return ring + + assert _power_of_two(self.world_size) + power_of_two = int(math.log2(self.world_size)) + median = power_of_two // 2 + balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median)) + row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1] + balanced_logical_mesh = [] + for row_index in range(row_size): + balanced_logical_mesh.append([]) + for column_index in range(column_size): + balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index]) + + homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict) + beta_list = [b for b in homogeneous_device_dict.keys()] + beta_list.sort() + beta_list.reverse() + homogeneous_types = len(beta_list) + best_logical_mesh = None + if homogeneous_types >= 2: + for _ in range(homogeneous_types - 1): + lowest_beta = beta_list.pop() + best_homogeneous_group = homogeneous_device_dict[lowest_beta] + # if the best homogeneous group contains all physical devices, + # we will build the logical device mesh based on it. Otherwise, + # we will check next level homogeneous group. + if _check_contain_all_devices(best_homogeneous_group): + # We choose the largest ring for each rank to maximum the best bus utilization. + best_logical_mesh = _construct_largest_ring(best_homogeneous_group) + break + + if homogeneous_types == 1 or best_logical_mesh is None: + # in this case, we use balanced logical mesh as the best + # logical mesh. + best_logical_mesh = balanced_logical_mesh + + return best_logical_mesh + + def extract_alpha_beta_for_device_mesh(self): + ''' + Extract the mesh_alpha list and mesh_beta list based on the + best logical mesh, which will be used to initialize the device mesh. + + Usage: + >>> physical_devices = [0, 1, 2, 3] + >>> ab_profiler = AlphaBetaProfiler(physical_devices) + >>> mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh() + >>> print(mesh_alpha) + [2.5917552411556242e-05, 0.00010312341153621673] + >>> print(mesh_beta) + [5.875573704655635e-11, 4.7361584445959614e-12] + ''' + best_logical_mesh = self.search_best_logical_mesh() + + first_axis = [row[0] for row in best_logical_mesh] + second_axis = best_logical_mesh[0] + + # init process group for both axes + first_axis_process_group = dist.new_group(first_axis) + second_axis_process_group = dist.new_group(second_axis) + + # extract alpha and beta for both axes + def _extract_alpha_beta(pg, pg_handler): + latency = self.profile_latency(pg, pg_handler) + bandwidth = self.profile_bandwidth(pg, pg_handler) + broadcast_object = [latency, bandwidth] + dist.broadcast_object_list(broadcast_object, src=pg[0]) + return broadcast_object + + first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group) + second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group) + mesh_alpha = [first_latency, second_latency] + mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth] + + return mesh_alpha, mesh_beta diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py new file mode 100644 index 000000000..4ab72dfe6 --- /dev/null +++ b/colossalai/device/calc_pipeline_strategy.py @@ -0,0 +1,127 @@ +from math import pow + +import numpy as np + + +def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): + submesh_choices = [] + i = 1 + p = -1 + while i <= num_devices_per_host: + i *= 2 + p += 1 + assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, " + f"while now num_devices_per_host = {num_devices_per_host}") + if mode == "alpa": + for i in range(p + 1): + submesh_choices.append((1, pow(2, i))) + for i in range(2, num_hosts + 1): + submesh_choices.append((i, num_devices_per_host)) + elif mode == "new": + for i in range(p // 2 + 1): + for j in range(i, p - i + 1): + submesh_choices.append((pow(2, i), pow(2, j))) + return submesh_choices + + +def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, + best_configs): + """Implementation of Alpa DP for pipeline strategy + Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf + + Arguments: + num_layers: K + num_devices: N*M + num_microbatches: B + submesh_choices: List[(n_i,m_i)] + compute_cost: t_intra + """ + # For f, layer ID start from 0 + # f[#pipeline stages, layer id that is currently being considered, number of devices used] + f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32) + f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32) + f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32) + f[0, num_layers, 0] = 0 + for s in range(1, num_layers + 1): + for k in range(num_layers - 1, -1, -1): + for d in range(1, num_devices + 1): + for m, submesh in enumerate(submesh_choices): + n_submesh_devices = np.prod(np.array(submesh)) + if n_submesh_devices <= d: + # TODO: [luzgh]: Why alpa needs max_n_succ_stages? Delete. + # if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]: + # ... + for i in range(num_layers, k, -1): + stage_cost = compute_cost[k, i, m] + new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost + if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]): + f[s, k, d] = new_cost + f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices]) + f_argmin[s, k, d] = (i, m, best_configs[k, i, m]) + best_s = -1 + best_total_cost = np.inf + for s in range(1, num_layers + 1): + if f[s, 0, num_devices] < best_total_cost: + best_s = s + best_total_cost = f[s, 0, num_devices] + + if np.isinf(best_total_cost): + return np.inf, None + + total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices] + current_s = best_s + current_layer = 0 + current_devices = num_devices + + res = [] + while current_s > 0 and current_layer < num_layers and current_devices > 0: + next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices]) + assert next_start_layer != -1 and current_devices != -1 + res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice)) + current_s -= 1 + current_layer = next_start_layer + current_devices -= np.prod(np.array(submesh_choices[submesh_choice])) + assert (current_s == 0 and current_layer == num_layers and current_devices == 0) + + return total_cost, res + + +def alpa_dp(num_layers, + num_devices, + num_microbatches, + submesh_choices, + num_autosharding_configs, + compute_cost, + gap=1e-6): + """Alpa auto stage dynamic programming. + Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py + + Arguments: + submesh_choices: List[(int,int)] + num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh) + compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs) + """ + assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices), + num_autosharding_configs), "Cost shape wrong." + all_possible_stage_costs = np.sort(np.unique(compute_cost)) + best_cost = np.inf + best_solution = None + last_max_stage_cost = 0.0 + # TODO: [luzgh]: Why alpa needs the num_autosharding_configs dimension in compute_cost? + # In dp_impl it seems the argmin n_config will be chosen. Just amin here. + best_configs = np.argmin(compute_cost, axis=3) + best_compute_cost = np.amin(compute_cost, axis=3) + assert len(all_possible_stage_costs), "no solution in auto stage construction." + for max_stage_cost in all_possible_stage_costs: + if max_stage_cost * num_microbatches >= best_cost: + break + if max_stage_cost - last_max_stage_cost < gap: + continue + cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, + max_stage_cost, best_configs) + if cost < best_cost: + best_cost = cost + best_solution = solution + last_max_stage_cost = max_stage_cost + + return best_cost, best_solution diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index df010e7d7..7596a100b 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -1,5 +1,6 @@ -from functools import reduce import operator +from functools import reduce + import torch import torch.distributed as dist @@ -11,7 +12,7 @@ class DeviceMesh: can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its own latency and bandwidth. We use alpha-beta model to model the communication cost. - + Arguments: physical_mesh_id (torch.Tensor): physical view of the devices in global rank. mesh_shape (torch.Size): shape of logical view. @@ -23,6 +24,7 @@ class DeviceMesh: during initializing the DeviceMesh instance if the init_process_group set to True. Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. (default: False) + need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True. """ def __init__(self, @@ -49,8 +51,11 @@ class DeviceMesh: self.need_flatten = need_flatten if self.init_process_group: self.process_groups_dict = self.create_process_groups_for_logical_mesh() - if self.need_flatten: + if self.need_flatten and self._logical_mesh_id.dim() > 1: self.flatten_device_mesh = self.flatten() + # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) + self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha, + self.mesh_beta) @property def shape(self): @@ -64,6 +69,18 @@ class DeviceMesh: def logical_mesh_id(self): return self._logical_mesh_id + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k != 'process_groups_dict': + setattr(result, k, __import__("copy").deepcopy(v, memo)) + else: + setattr(result, k, v) + + return result + def flatten(self): """ Flatten the logical mesh into an effective 1d logical mesh, @@ -90,7 +107,7 @@ class DeviceMesh: def create_process_groups_for_logical_mesh(self): ''' This method is used to initialize the logical process groups which will be used in communications - among logical device mesh. + among logical device mesh. Note: if init_process_group set to False, you have to call this method manually. Otherwise, the communication related function, such as ShapeConsistencyManager.apply will raise errors. ''' @@ -186,3 +203,38 @@ class DeviceMesh: penalty_factor = num_devices / 2.0 return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) + + +class FlattenDeviceMesh(DeviceMesh): + + def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None): + super().__init__(physical_mesh_id, + mesh_shape, + mesh_alpha, + mesh_beta, + init_process_group=False, + need_flatten=False) + # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars + self.mesh_alpha = max(self.mesh_alpha) + self.mesh_beta = min(self.mesh_beta) + # Different from original process_groups_dict, rank_list is not stored + self.process_number_dict = self.create_process_numbers_for_logical_mesh() + + def create_process_numbers_for_logical_mesh(self): + ''' + Build 1d DeviceMesh in column-major(0) and row-major(1) + for example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + ''' + num_devices = reduce(operator.mul, self.mesh_shape, 1) + process_numbers_dict = {} + process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist() + process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist() + return process_numbers_dict + + def mix_gather_cost(self, num_bytes): + num_devices = reduce(operator.mul, self.mesh_shape, 1) + return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1) diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py index 5693f3eac..d39fa5799 100644 --- a/colossalai/fx/__init__.py +++ b/colossalai/fx/__init__.py @@ -1,4 +1,4 @@ -from ._compatibility import compatibility, is_compatible_with_meta -from .graph_module import ColoGraphModule -from .passes import MetaInfoProp -from .tracer import ColoTracer, meta_trace +from ._compatibility import compatibility, is_compatible_with_meta +from .graph_module import ColoGraphModule +from .passes import MetaInfoProp, metainfo_trace +from .tracer import ColoTracer, meta_trace, symbolic_trace diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_registrations.py index 94387fbe0..8c0201c71 100644 --- a/colossalai/fx/_meta_registrations.py +++ b/colossalai/fx/_meta_registrations.py @@ -3,7 +3,7 @@ # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # for more meta_registrations -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch.utils._pytree import tree_map @@ -163,6 +163,23 @@ def meta_conv( return out +@register_meta(aten._convolution.default) +def meta_conv_1( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + *extra_args +): + out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) + return out + + @register_meta(aten.convolution_backward.default) def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask): @@ -179,6 +196,79 @@ def meta_adaptive_avg_pool2d_backward( return grad_input +# ================================ RNN ============================================= +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp +@register_meta(aten._cudnn_rnn.default) +def meta_cuda_rnn( + input, + weight, + weight_stride0, + weight_buf, + hx, + cx, + mode, + hidden_size, + proj_size, + num_layers, + batch_first, + dropout, + train, + bidirectional, + batch_sizes, + dropout_state, +): + + is_input_packed = len(batch_sizes) != 0 + if is_input_packed: + seq_length = len(batch_sizes) + mini_batch = batch_sizes[0] + batch_sizes_sum = input.shape[0] + else: + seq_length = input.shape[1] if batch_first else input.shape[0] + mini_batch = input.shape[0] if batch_first else input.shape[1] + batch_sizes_sum = -1 + + num_directions = 2 if bidirectional else 1 + out_size = proj_size if proj_size != 0 else hidden_size + if is_input_packed: + out_shape = [batch_sizes_sum, out_size * num_directions] + else: + out_shape = ( + [mini_batch, seq_length, out_size * num_directions] + if batch_first + else [seq_length, mini_batch, out_size * num_directions] + ) + output = input.new_empty(out_shape) + + cell_shape = [num_layers * num_directions, mini_batch, hidden_size] + cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape) + + hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size]) + + # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python) + reserve_shape = 0 if train else 0 + reserve = input.new_empty(reserve_shape, dtype=torch.uint8) + + return output, hy, cy, reserve, weight_buf + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp +@register_meta(aten._cudnn_rnn_backward.default) +def meta_cudnn_rnn_backward(input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs): + print(input, weight, hx, cx) + grad_input = torch.empty_like(input) + grad_weight = torch.empty_like(weight) + grad_hx = torch.empty_like(hx) + grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta') + return grad_input, grad_weight, grad_hx, grad_cx + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp # ============================== Activations ======================================= @register_meta(aten.relu.default) @@ -186,6 +276,11 @@ def meta_relu(input: torch.Tensor): return torch.empty_like(input) +@register_meta(aten.prelu.default) +def meta_prelu(input: torch.Tensor, weight: torch.Tensor): + return torch.empty_like(input) + + @register_meta(aten.hardswish.default) def meta_hardswish(input: torch.Tensor): return torch.empty_like(input) @@ -278,12 +373,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me # ================================== Misc ========================================== -#https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml @register_meta(aten.roll.default) def meta_roll(input: torch.Tensor, shifts, dims): return input +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp +@register_meta(aten._local_scalar_dense.default) +def meta_local_scalar_dense(self: torch.Tensor): + return 0 + + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp @register_meta(aten.where.self) def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): @@ -317,7 +418,7 @@ def meta_index_Tensor(self, indices): indices = result assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" # expand_outplace - import torch._refs as refs # avoid import cycle in mypy + import torch._refs as refs indices = list(refs._maybe_broadcast(*indices)) # add missing null tensors diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 684028c01..492ebf918 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -1,14 +1,37 @@ -import colossalai +from typing import Any, Callable, Dict, Iterable, List, Tuple + import torch -from typing import List, Callable, Any, Tuple, Dict, Iterable + +import colossalai try: - from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name - from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin + from torch.fx.graph import ( + CodeGen, + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + inplace_methods, + magic_methods, + ) + from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg CODEGEN_AVAILABLE = True except: - from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin - from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name + from torch.fx.graph import ( + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_args, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + magic_methods, + ) + from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg CODEGEN_AVAILABLE = False if CODEGEN_AVAILABLE: @@ -27,7 +50,7 @@ def _gen_saved_tensors_hooks(): return (x.device, x.cpu()) else: return x - + def pack_hook_no_input(self, x): if getattr(x, "offload", True): return (x.device, x.cpu()) @@ -48,11 +71,9 @@ def pack_hook_no_input(self, x): def _gen_save_tensors_hooks_context(offload_input=True) -> str: """Generate customized saved_tensors_hooks - Args: - offload_input (bool, optional): whether we need offload input, if offload_input=False, + offload_input (bool, optional): whether we need offload input, if offload_input=False, we will use self.pack_hook_no_input instead. Defaults to True. - Returns: str: generated context """ @@ -111,8 +132,8 @@ def _find_ckpt_regions(nodes: List[Node]): current_region = None for idx, node in enumerate(nodes): - if hasattr(node, 'activation_checkpoint'): - act_ckpt_label = node.activation_checkpoint + if 'activation_checkpoint' in node.meta: + act_ckpt_label = node.meta['activation_checkpoint'] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -129,7 +150,7 @@ def _find_ckpt_regions(nodes: List[Node]): current_region = act_ckpt_label start = idx end = -1 - elif current_region is not None and not hasattr(node, 'activation_checkpoint'): + elif current_region is not None and not 'activation_checkpoint' in node.meta: # used to check the case below # node ckpt states = [ckpt, ckpt, non-ckpt] end = idx - 1 @@ -144,7 +165,7 @@ def _find_ckpt_regions(nodes: List[Node]): def _find_offload_regions(nodes: List[Node]): """This function is to find the offload regions - In pofo algorithm, during annotation, we will annotate the offload region with the + In pofo algorithm, during annotation, we will annotate the offload region with the list in the form of [idx, offload_input, offload_bar]. idx indicates the offload region's index, offload_input is a bool type indicates whether we need to offload the input, offload_bar is a bool type indicates whether we need to offload all the @@ -157,8 +178,8 @@ def _find_offload_regions(nodes: List[Node]): current_region = None for idx, node in enumerate(nodes): - if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable): - act_offload_label = node.activation_offload + if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable): + act_offload_label = node.meta['activation_offload'] if current_region == None: current_region = act_offload_label @@ -212,18 +233,16 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen def _end_of_ckpt(node: Node, check_idx: int) -> bool: """Check if the node could end the ckpt region - Args: node (Node): torch.fx.Node - check_idx (int): the index of checkpoint level for + check_idx (int): the index of checkpoint level for nested checkpoint - Returns: bool """ - if hasattr(node, "activation_checkpoint"): - if isinstance(node.activation_checkpoint, list): - return node.activation_checkpoint[check_idx] == None + if 'activation_checkpoint' in node.meta: + if isinstance(node.meta['activation_checkpoint'], list): + return node.meta['activation_checkpoint'][check_idx] == None else: return False else: @@ -232,7 +251,7 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool: def _find_nested_ckpt_regions(nodes, check_idx=0): """ - Find the nested checkpoint regions given a list of consecutive nodes. The outputs + Find the nested checkpoint regions given a list of consecutive nodes. The outputs will be list of tuples, each tuple is in the form of (start_index, end_index). """ ckpt_regions = [] @@ -241,11 +260,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0): current_region = None for idx, node in enumerate(nodes): - if hasattr(node, 'activation_checkpoint'): - if isinstance(getattr(node, 'activation_checkpoint'), int): - act_ckpt_label = node.activation_checkpoint + if 'activation_checkpoint' in node.meta: + if isinstance(node.meta['activation_checkpoint'], int): + act_ckpt_label = node.meta['activation_checkpoint'] else: - act_ckpt_label = node.activation_checkpoint[check_idx] + act_ckpt_label = node.meta['activation_checkpoint'][check_idx] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -287,7 +306,6 @@ def emit_ckpt_func(body, level=0, in_ckpt=False): """Emit ckpt fuction in nested way - Args: body: forward code, in recursive calls, this part will be checkpoint functions code @@ -303,8 +321,8 @@ def emit_ckpt_func(body, inputs, outputs = _find_input_and_output_nodes(node_list) # if the current checkpoint function use int as label, using old generation method - if isinstance(node_list[0].activation_checkpoint, int): - label = node_list[0].activation_checkpoint + if isinstance(node_list[0].meta['activation_checkpoint'], int): + label = node_list[0].meta['activation_checkpoint'] ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) ckpt_func.append(f'{ckpt_fn_def}\n') for node in node_list: @@ -313,7 +331,7 @@ def emit_ckpt_func(body, delete_unused_value_func(node, ckpt_func) ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = getattr(node_list[0], "activation_offload", False) + activation_offload = node_list[0].meta.get('activation_offload', False) usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) usage += "\n" body.append(usage) @@ -322,12 +340,12 @@ def emit_ckpt_func(body, else: # label given by each layer, e.g. if you are currently at level [0, 1, 1] # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]]) + label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]]) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) ckpt_func.append(f'{ckpt_fn_def}\n') # if there is more level to fetch - if level + 1 < len(node_list[0].activation_checkpoint): + if level + 1 < len(node_list[0].meta['activation_checkpoint']): ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] @@ -354,7 +372,7 @@ def emit_ckpt_func(body, ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') ckpt_func += ckpt_func_buffer - activation_offload = getattr(node_list[0], "activation_offload", False) + activation_offload = node_list[0].meta.get('activation_offload', False) usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' if in_ckpt: usage = ' ' + usage @@ -368,7 +386,7 @@ def emit_ckpt_func(body, delete_unused_value_func(node, ckpt_func) ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = getattr(node_list[0], "activation_offload", False) + activation_offload = node_list[0].meta.get('activation_offload', False) usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' if in_ckpt: usage = ' ' + usage @@ -379,7 +397,6 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use this function to emit the activation checkpoint codes. - Args: body: forward code ckpt_func: checkpoint functions code @@ -564,8 +581,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # we need to check if the checkpoint need to offload the input start_node_idx = start_idx[label] - if hasattr(node_list[start_node_idx], 'activation_offload'): - activation_offload = node_list[start_node_idx].activation_offload + if 'activation_offload' in node_list[start_node_idx].meta: + activation_offload = node_list[start_node_idx].meta['activation_offload'] else: activation_offload = False @@ -577,8 +594,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if input_node.op != "placeholder": non_leaf_input = 1 for user in input_node.users: - if hasattr(user, "activation_checkpoint"): - if user.activation_checkpoint == label: + if 'activation_checkpoint' in user.meta: + if user.meta['activation_checkpoint'] == label: if user.op == "call_module": if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace @@ -616,10 +633,8 @@ if CODEGEN_AVAILABLE: def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. - We call this for names that reference objects external to the Graph, like functions or types. - Returns: the global name that should be used to reference 'obj' in generated source. """ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device @@ -796,7 +811,7 @@ if CODEGEN_AVAILABLE: # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes): + if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) @@ -829,7 +844,6 @@ if CODEGEN_AVAILABLE: code = '\n'.join(' ' + line for line in code.split('\n')) fn_code = f""" {wrap_stmts} - {prologue} {code}""" return PythonCode(fn_code, globals_) @@ -851,10 +865,8 @@ else: def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. - We call this for names that reference objects external to the Graph, like functions or types. - Returns: the global name that should be used to reference 'obj' in generated source. """ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device @@ -999,7 +1011,7 @@ else: # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes): + if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) @@ -1040,7 +1052,6 @@ else: # in forward function fn_code = f""" {wrap_stmts} - {ckpt_func} def forward({', '.join(orig_args)}){maybe_return_annotation[0]}: {code}""" diff --git a/colossalai/fx/passes/__init__.py b/colossalai/fx/passes/__init__.py index 43ac14ec4..6f948cb2d 100644 --- a/colossalai/fx/passes/__init__.py +++ b/colossalai/fx/passes/__init__.py @@ -1,4 +1,4 @@ from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass -from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass -from .meta_info_prop import MetaInfoProp from .concrete_info_prop import ConcreteInfoProp +from .meta_info_prop import MetaInfoProp, metainfo_trace +from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py index 4013d79f7..373d20c51 100644 --- a/colossalai/fx/passes/adding_split_node_pass.py +++ b/colossalai/fx/passes/adding_split_node_pass.py @@ -1,7 +1,7 @@ import torch - from torch.fx import symbolic_trace from torch.fx.node import Node + from colossalai.fx.passes.split_module import split_module @@ -9,6 +9,30 @@ def pipe_split(): pass +def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int): + """ + In avgnode_split_pass, simpliy split graph by node number. + """ + mod_graph = gm.graph + avg_num_node = len(mod_graph.nodes) // pp_size + accumulate_num_node = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + accumulate_num_node += 1 + if accumulate_num_node >= avg_num_node: + accumulate_num_node = 0 + pp_size -= 1 + if node.next.op == 'output': + with mod_graph.inserting_before(node): + split_node = mod_graph.create_node('call_function', pipe_split) + else: + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + gm.recompile() + return gm + + def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): """ In balanced_split_pass, we split module by the size of parameters(weights+bias). @@ -37,6 +61,21 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): else: with mod_graph.inserting_after(node): split_node = mod_graph.create_node('call_function', pipe_split) + if pp_size > 1: + node_counter = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == 'placeholder': + continue + elif node_counter == 0: + node_counter += 1 + else: + pp_size -= 1 + node_counter = 0 + with mod_graph.inserting_before(node): + split_node = mod_graph.create_node('call_function', pipe_split) + gm.recompile() return gm @@ -102,7 +141,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int): return gm -def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule): +def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False): # TODO(lyl): use partition IR to assign partition ID to each node. # Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph # In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node @@ -114,7 +153,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule): part_idx += 1 return part_idx - split_mod = split_module(annotated_gm, None, split_callback) + split_mod = split_module(annotated_gm, None, split_callback, merge_output) split_submodules = [] for name, submodule in split_mod.named_modules(): if isinstance(submodule, torch.fx.GraphModule): diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index e38ddbdce..52000ebe5 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -1,7 +1,9 @@ +import math from typing import List, Set, Tuple + import torch from torch.fx import GraphModule, Node -import math + from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp __all__ = ['chen_greedy'] diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index 01c3bdb35..5b8d0da9f 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -1,15 +1,17 @@ +import math import sys from typing import List, Tuple -from colossalai.fx.profiler.memory import calculate_fwd_in + from torch.fx import Node -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp -import math -from .linearize import linearize -from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function + from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size from colossalai.logging import get_dist_logger +from .linearize import linearize +from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence + # global vairable to indicate whether the solver is failed SOLVER_FAILED = False @@ -18,7 +20,7 @@ SOLVER_FAILED = False # https://gitlab.inria.fr/hiepacs/rotor # paper link: https://hal.inria.fr/hal-02352969 def _compute_table(chain: Chain, mmax) -> Tuple: - """Returns the optimal table: a tuple containing: + """Returns the optimal table: a tuple containing: Opt[m][lmin][lmax] with lmin = 0...chain.length and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint @@ -127,7 +129,7 @@ def _fwd_xbar(node: List[Node]) -> int: """Get the forward xbar of a node Args: - node (List[Node]): List of torch.fx Node, + node (List[Node]): List of torch.fx Node, indicates a node in linearized graph Returns: @@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule, # build module if module not found except ModuleNotFoundError: - import subprocess import os + import subprocess logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0]) this_dir = os.path.dirname(os.path.abspath(__file__)) result = subprocess.Popen( diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py index 191d8d67d..ab38e8cb1 100644 --- a/colossalai/fx/passes/concrete_info_prop.py +++ b/colossalai/fx/passes/concrete_info_prop.py @@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple import torch import torch.fx -from colossalai.fx._compatibility import compatibility -from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module) from torch.fx.node import Argument, Node, Target from torch.utils._pytree import tree_flatten +from colossalai.fx._compatibility import compatibility +from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module + @compatibility(is_backward_compatible=True) class ConcreteInfoProp(torch.fx.Interpreter): @@ -22,17 +23,17 @@ class ConcreteInfoProp(torch.fx.Interpreter): DIM_HIDDEN = 16 DIM_OUT = 16 model = torch.nn.Sequential( - torch.nn.Linear(DIM_IN, DIM_HIDDEN), + torch.nn.Linear(DIM_IN, DIM_HIDDEN), torch.nn.Linear(DIM_HIDDEN, DIM_OUT), ).cuda() input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda") gm = symbolic_trace(model) interp = ConcreteInfoProp(gm) interp.run(input_sample) - print(interp.summary(unit='kb')) - - - output of above code is + print(interp.summary(unit='kb')) + + + output of above code is Op type Op Forward time Backward time SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP ----------- ------- ----------------------- ------------------------ ------------- --------- --------- --------- --------- placeholder input_1 0.0 s 0.0 s False 0.00 KB 0.00 KB 0.00 KB 0.00 KB @@ -229,8 +230,8 @@ class ConcreteInfoProp(torch.fx.Interpreter): def summary(self, unit: str = 'MB') -> str: """ - Summarizes the memory and FLOPs statistics of the `GraphModule` in - tabular format. Note that this API requires the ``tabulate`` module + Summarizes the memory and FLOPs statistics of the `GraphModule` in + tabular format. Note that this API requires the ``tabulate`` module to be installed. """ # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py deleted file mode 100644 index 2e735a25d..000000000 --- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py +++ /dev/null @@ -1,193 +0,0 @@ -import builtins -import copy -import operator -from ast import NodeTransformer -from copy import deepcopy -from typing import List - -import torch -from torch.fx import symbolic_trace -from torch.fx.node import Node - -from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.passes.split_module import split_module -from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec, _all_reduce, pattern_to_func_dict -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec - -shape_consistency_manager = ShapeConsistencyManager() - - -def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index): - origin_sharding_spec = origin_dict[node_index] - target_sharding_spec = input_dict[node_index][user_node_index] - return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec) - - -def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data): - - comm_action = comm_actions_dict[node_index][op_data] - if isinstance(comm_action.comm_spec, CommSpec): - rst = comm_action.comm_spec.covert_spec_to_action(tensor) - else: - origin_sharding_spec = comm_action.comm_spec['src_spec'] - tgt_sharding_spec = comm_action.comm_spec['tgt_spec'] - rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec) - return rst - - -def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh): - mod_graph = gm.graph - nodes = tuple(mod_graph.nodes) - - # the dict to get origin sharding spec of node - origin_node_sharding_spec_dict = {} - for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): - strategies_vector = node.strategies_vector - setattr(node, 'best_strategy', strategies_vector[strategy_index]) - setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) - origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( - str(node)) - - # apply the sharding spec of parameters - for node in nodes: - if node.op == 'call_module': - target_module = node.graph.owning_module.get_submodule(node.target) - for name, param in target_module.named_parameters(): - target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) - if target_sharding_spec.dim_partition_dict != {}: - origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) - setattr(param, 'sharding_spec', origin_sharding_spec) - param_sharded = torch.nn.Parameter( - shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, - target_sharding_spec).detach().clone()) - else: - param_sharded = param - setattr(target_module, name, param_sharded) - comm_actions = node.best_strategy.communication_actions - for operation_data, comm_action in comm_actions.items(): - comm_spec_to_use = comm_action.comm_spec - if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK: - - def wrapper(param, comm_spec): - - def hook_fn(grad): - _all_reduce(grad, comm_spec) - - param.register_hook(hook_fn) - - wrapper(param_sharded, comm_spec_to_use) - - sharded_buffer_dict = {} - for name, buffer in target_module.named_buffers(): - origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {}) - setattr(buffer, 'sharding_spec', origin_sharding_spec) - target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) - buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec) - sharded_buffer_dict[name] = buffer_sharded - - for name, buffer_sharded in sharded_buffer_dict.items(): - setattr(target_module, name, buffer_sharded.detach().clone()) - - # the dict to get input sharding specs of user node - sharding_spec_convert_dict = {} - for index, node in enumerate(nodes): - target_sharding_specs = [] - for user_node in node.strategies_vector.successor_nodes: - target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name)) - target_sharding_specs.append(target_sharding_spec) - sharding_spec_convert_dict[index] = target_sharding_specs - - # the dict to record comm actions of nodes - comm_actions_dict = {} - for index, node in enumerate(nodes): - comm_action_dict = {} - for op_data, comm_action in node.best_strategy.communication_actions.items(): - comm_action_dict[op_data.name] = comm_action - comm_actions_dict[index] = comm_action_dict - - # add above dicts into graph - for node in nodes: - if node.op != 'placeholder': - with mod_graph.inserting_before(node): - input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') - origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') - comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict') - break - - return sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict - - -def shape_consistency_pass(gm: torch.fx.GraphModule): - mod_graph = gm.graph - nodes = tuple(mod_graph.nodes) - input_dict_node = None - origin_dict_node = None - - # mapping the node into the origin graph index - node_to_index_dict = {} - index = 0 - for node in nodes: - if node.target == 'sharding_spec_convert_dict': - input_dict_node = node - continue - if node.target == 'origin_node_sharding_spec_dict': - origin_dict_node = node - continue - if node.target == 'comm_actions_dict': - comm_actions_dict_node = node - continue - if not hasattr(node, 'best_strategy'): - continue - node_to_index_dict[node] = index - index += 1 - assert input_dict_node is not None - - # add shape consistency apply function into graph - for node in nodes: - if not hasattr(node, 'best_strategy') or node.op == 'output': - continue - - for user_node in node.strategies_vector.successor_nodes: - user_node_index = user_node.strategies_vector.predecessor_nodes.index(node) - with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', - runtime_apply, - args=(node, origin_dict_node, input_dict_node, - node_to_index_dict[node], user_node_index)) - - origin_index_args = user_node.args.index(node) - new_args = list(user_node.args) - new_args[origin_index_args] = shape_consistency_node - user_node.args = new_args - - comm_actions = node.best_strategy.communication_actions - for op_data, comm_action in comm_actions.items(): - comm_object = node.args[comm_action.arg_index] - if op_data.type == OperationDataType.PARAM: - continue - if comm_action.comm_type == CommType.BEFORE: - with mod_graph.inserting_before(node): - comm_spec_apply_node = mod_graph.create_node('call_function', - runtime_comm_spec_apply, - args=(comm_object, comm_actions_dict_node, - node_to_index_dict[node], op_data.name)) - new_args = list(node.args) - new_args[comm_action.arg_index] = comm_spec_apply_node - node.args = new_args - elif comm_action.comm_type == CommType.AFTER: - with mod_graph.inserting_after(node): - comm_spec_apply_node = mod_graph.create_node('call_function', - runtime_comm_spec_apply, - args=(node, comm_actions_dict_node, - node_to_index_dict[node], op_data.name)) - user_list = list(node.users.keys()) - for user in user_list: - if user == comm_spec_apply_node: - continue - new_args = list(user.args) - new_args[new_args.index(node)] = comm_spec_apply_node - user.args = tuple(new_args) - # TODO: consider other OperationDataType, such as OperationDataType.OUTPUT - return gm diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 4fab5d041..5137494ad 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple import torch import torch.fx -from colossalai.fx._compatibility import compatibility -from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp, - profile_function, profile_method, profile_module) from torch.fx.node import Argument, Node, Target from torch.utils._pytree import tree_map +from colossalai.fx._compatibility import compatibility, is_compatible_with_meta +from colossalai.fx.profiler import ( + GraphInfo, + activation_size, + calculate_fwd_in, + calculate_fwd_out, + calculate_fwd_tmp, + profile_function, + profile_method, + profile_module, +) + @compatibility(is_backward_compatible=True) class TensorMetadata(NamedTuple): @@ -52,7 +61,7 @@ class MetaInfoProp(torch.fx.Interpreter): DIM_HIDDEN = 16 DIM_OUT = 16 model = torch.nn.Sequential( - torch.nn.Linear(DIM_IN, DIM_HIDDEN), + torch.nn.Linear(DIM_IN, DIM_HIDDEN), torch.nn.Linear(DIM_HIDDEN, DIM_OUT), ) input_sample = torch.rand(BATCH_SIZE, DIM_IN) @@ -60,9 +69,9 @@ class MetaInfoProp(torch.fx.Interpreter): interp = MetaInfoProp(gm) interp.run(input_sample) print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB - - - # output of above code is + + + # output of above code is Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP ----------- ------- --------------- ---------------- --------- --------- --------- --------- placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB @@ -248,8 +257,8 @@ class MetaInfoProp(torch.fx.Interpreter): def summary(self, unit: str = 'MB') -> str: """ - Summarizes the memory and FLOPs statistics of the `GraphModule` in - tabular format. Note that this API requires the ``tabulate`` module + Summarizes the memory and FLOPs statistics of the `GraphModule` in + tabular format. Note that this API requires the ``tabulate`` module to be installed. """ # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py @@ -306,3 +315,38 @@ class MetaInfoProp(torch.fx.Interpreter): ] return tabulate(node_summaries, headers=headers, stralign='right') + + +def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None: + """ + MetaInfo tracing API + + Given a ``GraphModule`` and a sample input, this API will trace the MetaInfo of a single training cycle, + and annotate them on ``gm.graph``. + + Uses: + >>> model = ... + >>> gm = symbolic_trace(model) + >>> args = ... # sample input to the ``GraphModule`` + >>> metainfo_trace(gm, *args) + + Args: + gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo. + verbose (bool, optional): Whether to show ``MetaInfoProp.summary()`. Defaults to False. + unit (str, optional): The unit of memory. Defaults to "MB". + + Returns: + torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo. + """ + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + interp = MetaInfoProp(gm.to(device)) + if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + args = tree_map(lambda x: MetaTensor(x, fake_device=device), args) + kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs) + interp.propagate(*args, **kwargs) + if verbose: + interp.summary(unit) + gm.to('cpu') + del interp + return gm diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py index 8671855f4..bc257edc8 100644 --- a/colossalai/fx/passes/split_module.py +++ b/colossalai/fx/passes/split_module.py @@ -38,11 +38,11 @@ def split_module( m: GraphModule, root_m: torch.nn.Module, split_callback: Callable[[torch.fx.node.Node], int], + merge_output = False, ): """ Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py Creates subgraphs out of main graph - Args: m (GraphModule): Graph module to split root_m (torch.nn.Module): root nn module. Not currently used. Included @@ -52,52 +52,40 @@ def split_module( that maps a given Node instance to a numeric partition identifier. split_module will use this function as the policy for which operations appear in which partitions in the output Module. - Returns: GraphModule: the module after split. - Example: - This is a sample setup: - import torch from torch.fx.symbolic_trace import symbolic_trace from torch.fx.graph_module import GraphModule from torch.fx.node import Node from colossalai.fx.passes.split_module import split_module - class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) - def forward(self, x, y): z = self.linear(x + self.param).clamp(min=0.0, max=1.0) w = self.linear(y).clamp(min=0.0, max=1.0) return z + w - # symbolically trace model my_module = MyModule() my_module_traced = symbolic_trace(my_module) - # random mod partitioning partition_counter = 0 NPARTITIONS = 3 - def mod_partition(node: Node): global partition_counter partition = partition_counter % NPARTITIONS partition_counter = (partition_counter + 1) % NPARTITIONS return partition - # split module in module with submodules module_with_submodules = split_module( my_module_traced, my_module, mod_partition ) - Output looks like this. Original graph is broken into partitions - > print(module_with_submodules) GraphModule( (submod_0): GraphModule( @@ -108,7 +96,6 @@ def split_module( ) (submod_2): GraphModule() ) - def forward(self, x, y): param = self.param submod_0 = self.submod_0(x, param, y); x = param = y = None @@ -119,10 +106,8 @@ def split_module( getitem_3 = submod_1[1]; submod_1 = None submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None return submod_2 - Output of split module is the same as output of input traced module. This is an example within a test setting: - > orig_out = my_module_traced(x, y) > submodules_out = module_with_submodules(x, y) > self.assertEqual(orig_out, submodules_out) @@ -147,6 +132,29 @@ def split_module( use_partition.inputs.setdefault(def_node.name) if def_partition_name is not None: use_partition.partitions_dependent_on.setdefault(def_partition_name) + + def record_output( + def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node] + ): # noqa: B950 + def_partition_name = getattr(def_node, "_fx_partition", None) + use_partition_name = getattr(use_node, "_fx_partition", None) + if def_partition_name != use_partition_name: + if def_partition_name is not None: + def_partition = partitions[def_partition_name] + def_partition.outputs.setdefault(def_node.name) + if use_partition_name is not None: + def_partition.partition_dependents.setdefault(use_partition_name) + + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.inputs.setdefault(def_node.name) + if def_partition_name is not None: + use_partition.partitions_dependent_on.setdefault(def_partition_name) + use_partition.outputs.setdefault(def_node.name) + else: + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.outputs.setdefault(def_node.name) # split nodes into parititons for node in m.graph.nodes: @@ -155,7 +163,10 @@ def split_module( if node.op in ["placeholder"]: continue if node.op == 'output': - torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None)) + if merge_output: + torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev)) + else: + torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None)) continue partition_name = str(split_callback(node)) @@ -235,10 +246,10 @@ def split_module( for node in m.graph.nodes: if node.op == 'placeholder': if version.parse(torch.__version__) < version.parse('1.11.0'): - base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type) + base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type) else: default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty - base_mod_env[node.name] = base_mod_graph.placeholder(node.name, + base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type, default_value=default_value) base_mod_env[node.name].meta = node.meta.copy() @@ -278,4 +289,9 @@ def split_module( if node.op == 'output': base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 - return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + for partition_name in sorted_partitions: + partition = partitions[partition_name] + + new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + + return new_gm diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py index 842c9d52e..bb4f3cd6a 100644 --- a/colossalai/fx/passes/utils.py +++ b/colossalai/fx/passes/utils.py @@ -1,9 +1,8 @@ import torch -from typing import Dict, Set +from typing import Dict from torch.fx.node import Node, map_arg from torch.fx.graph import Graph - def get_comm_size(prev_partition, next_partition): """ Given two partitions (parent and child), @@ -32,7 +31,6 @@ def get_comm_size(prev_partition, next_partition): def get_leaf(graph: Graph): """ Given a graph, return leaf nodes of this graph. - Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph, we will get a normal DAG. Leaf nodes in this context means leaf nodes in that DAG. """ @@ -57,7 +55,6 @@ def is_leaf(graph: Graph, node: Node): def get_top(graph: Graph): """ Given a graph, return top nodes of this graph. - Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph, we will get a normal DAG. Top nodes in this context means nodes with BFS level 0 in that DAG. """ @@ -100,7 +97,6 @@ def get_all_consumers(graph: Graph, node: Node): def assign_bfs_level_to_nodes(graph: Graph): """ Give a graph, assign bfs level to each node of this graph excluding ``placeholder`` and ``output`` nodes. - Example: class MLP(torch.nn.Module): def __init__(self, dim: int): @@ -110,8 +106,6 @@ def assign_bfs_level_to_nodes(graph: Graph): self.linear3 = torch.nn.Linear(dim, dim) self.linear4 = torch.nn.Linear(dim, dim) self.linear5 = torch.nn.Linear(dim, dim) - - def forward(self, x): l1 = self.linear1(x) l2 = self.linear2(x) @@ -165,10 +159,8 @@ def assign_bfs_level_to_nodes(graph: Graph): def get_node_module(node) -> torch.nn.Module: """ Find the module associated with the given node. - Args: node (torch.fx.Node): a torch.fx.Node object in the fx computation graph - Returns: torch.nn.Module: the module associated with the given node """ @@ -177,3 +169,4 @@ def get_node_module(node) -> torch.nn.Module: assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}' module = node.graph.owning_module.get_submodule(node.target) return module + diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index b520ff124..8bcbde0eb 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -1,12 +1,18 @@ from .._compatibility import is_compatible_with_meta if is_compatible_with_meta(): - from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp from .opcount import flop_mapping from .profiler import profile_function, profile_method, profile_module + from .shard_utils import ( + calculate_bwd_time, + calculate_fwd_in, + calculate_fwd_out, + calculate_fwd_time, + calculate_fwd_tmp, + ) from .tensor import MetaTensor else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out from .dataflow import GraphInfo -from .memory import activation_size, is_inplace, parameter_size +from .memory_utils import activation_size, is_inplace, parameter_size diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index f7009a84a..a5e888032 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -6,7 +6,7 @@ from typing import Dict, List from torch.fx import Graph, Node from .._compatibility import compatibility -from .memory import activation_size, is_inplace +from .memory_utils import activation_size, is_inplace class Phase(Enum): @@ -29,7 +29,7 @@ class GraphInfo: placeholders saved for | | \__________ | | backward. | | \ | | | [fwd_tmp] ------> [bwd_tmp] | <----- - | | \_________ | | [bwd_tmp] marks the peak memory + | | \_________ | | [bwd_tmp] marks the peak memory | / \ \ | | in backward pass. [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- in [fwd_tmp] because | | \_____ | | @@ -80,18 +80,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: Nodes should have attribute `out` indicating the output of each node. ============================================================================ Placeholder ----> p o <---- We need to keep track of grad out - |\________ | + |\________ | ↓ ↘| f --------> b |\ \_____ ↑ | \ ↘ / f f ----> b <---- Not every forward result needs to be saved for backward | \____ ↑ - ↘ ↘| + ↘ ↘| f ----> b <---- Backward can be freed as soon as it is required no more. ↘ ↗ l - ============================================================================= + ============================================================================= Args: graph (Graph): The autograd graph with nodes marked for keyword `phase`. diff --git a/colossalai/fx/profiler/experimental/__init__.py b/colossalai/fx/profiler/experimental/__init__.py index fbb6ff624..a5387981e 100644 --- a/colossalai/fx/profiler/experimental/__init__.py +++ b/colossalai/fx/profiler/experimental/__init__.py @@ -1,5 +1,5 @@ -from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp from .profiler import profile_function, profile_method, profile_module from .profiler_function import * from .profiler_module import * from .registry import meta_profiler_function, meta_profiler_module +from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index fbeea5128..5c545260e 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -5,7 +5,7 @@ import torch from torch.fx.node import Argument, Target from ..._compatibility import compatibility -from ..memory import activation_size +from ..memory_utils import activation_size from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD from .registry import meta_profiler_function, meta_profiler_module @@ -27,7 +27,7 @@ class GraphInfo: placeholders saved for | | \__________ | | backward. | | \ | | | [fwd_tmp] ------> [bwd_tmp] | <----- - | | \_________ | | [bwd_tmp] marks the peak memory + | | \_________ | | [bwd_tmp] marks the peak memory | / \ \ | | in backward pass. [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- in [fwd_tmp] because | | | \_____ | | @@ -76,14 +76,14 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int @compatibility(is_backward_compatible=True) def profile_function(target: 'Target') -> Callable: """ - Wrap a `call_function` node or `torch.nn.functional` in order to + Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost and FLOPs of the execution. Unfortunately, backward memory cost and FLOPs are estimated results. - + Warnings: You may only use tensors with `device=meta` for this wrapped function. Only original `torch.nn.functional` are available. - + Examples: >>> input = torch.rand(100, 100, 100, 100, device='meta') >>> func = torch.nn.functional.relu @@ -142,13 +142,13 @@ def profile_method(target: 'Target') -> Callable: @compatibility(is_backward_compatible=True) def profile_module(module: torch.nn.Module) -> Callable: """ - Wrap a `call_module` node or `torch.nn` in order to + Wrap a `call_module` node or `torch.nn` in order to record the memory cost and FLOPs of the execution. - + Warnings: You may only use tensors with `device=meta` for this wrapped function. Only original `torch.nn` are available. - + Example: >>> input = torch.rand(4, 3, 224, 224, device='meta') >>> mod = torch.nn.Conv2d(3, 128, 3) diff --git a/colossalai/fx/profiler/experimental/memory.py b/colossalai/fx/profiler/experimental/shard_utils.py similarity index 100% rename from colossalai/fx/profiler/experimental/memory.py rename to colossalai/fx/profiler/experimental/shard_utils.py diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py new file mode 100644 index 000000000..6ccbcb01c --- /dev/null +++ b/colossalai/fx/profiler/memory_utils.py @@ -0,0 +1,71 @@ +from typing import Dict, List, Tuple, Union + +import torch +from torch.fx import GraphModule, Node + +from .._compatibility import compatibility, is_compatible_with_meta + +__all__ = ['activation_size', 'parameter_size', 'is_inplace'] + + +@compatibility(is_backward_compatible=True) +def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: + """Calculate activation size of a node. + + Args: + activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`. + + Returns: + int: The activation size, unit is byte. + """ + act_size = 0 + if isinstance(out, torch.Tensor): + if out.is_quantized: + act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size() + else: + act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size() + elif isinstance(out, dict): + value_list = [v for _, v in out.items()] + act_size += activation_size(value_list) + elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set): + for element in out: + act_size += activation_size(element) + return act_size + + +@compatibility(is_backward_compatible=True) +def parameter_size(mod: torch.nn.Module) -> int: + """Calculate parameter size of a node. + + Args: + mod (torch.nn.Module): The target `torch.nn.Module`. + + Returns: + int: The parameter size, unit is byte. + """ + param_size = 0 + for param in mod.parameters(): + param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size() + return param_size + + +def is_inplace(n: Node): + """Get the inplace argument from torch.fx.Node + + Args: + node (Node): torch.fx.Node + + Returns: + bool: indicates whether this op is inplace + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + if is_compatible_with_meta(): + from .constants import ALIAS_ATEN + if n.target in ALIAS_ATEN: + inplace = True + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + + return inplace diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index 8bd972ff3..1c39dc247 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -7,6 +7,7 @@ from numbers import Number from typing import Any, Callable, List import torch +from packaging import version aten = torch.ops.aten @@ -32,7 +33,7 @@ def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: # inputs is a list of length 3. input_shapes = [v.shape for v in inputs[1:3]] # input_shapes[0]: [batch size, input feature dimension] - # input_shapes[1]: [batch size, output feature dimension] + # input_shapes[1]: [input feature dimension, output feature dimension] assert len(input_shapes[0]) == 2, input_shapes[0] assert len(input_shapes[1]) == 2, input_shapes[1] batch_size, input_dim = input_shapes[0] @@ -188,131 +189,136 @@ def zero_flop_jit(*args): return 0 -flop_mapping = { +if version.parse(torch.__version__) >= version.parse('1.12.0'): + flop_mapping = { # gemm - aten.mm.default: matmul_flop_jit, - aten.matmul.default: matmul_flop_jit, - aten.addmm.default: addmm_flop_jit, - aten.bmm.default: bmm_flop_jit, + aten.mm.default: matmul_flop_jit, + aten.matmul.default: matmul_flop_jit, + aten.addmm.default: addmm_flop_jit, + aten.bmm.default: bmm_flop_jit, # convolution - aten.convolution.default: conv_flop_jit, - aten._convolution.default: conv_flop_jit, - aten.convolution_backward.default: conv_backward_flop_jit, + aten.convolution.default: conv_flop_jit, + aten._convolution.default: conv_flop_jit, + aten.convolution_backward.default: conv_backward_flop_jit, # normalization - aten.native_batch_norm.default: batchnorm_flop_jit, - aten.native_batch_norm_backward.default: batchnorm_flop_jit, - aten.cudnn_batch_norm.default: batchnorm_flop_jit, - aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), - aten.native_layer_norm.default: norm_flop_counter(2, 0), - aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), + aten.native_batch_norm.default: batchnorm_flop_jit, + aten.native_batch_norm_backward.default: batchnorm_flop_jit, + aten.cudnn_batch_norm.default: batchnorm_flop_jit, + aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), + aten.native_layer_norm.default: norm_flop_counter(2, 0), + aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), # pooling - aten.avg_pool1d.default: elementwise_flop_counter(1, 0), - aten.avg_pool2d.default: elementwise_flop_counter(1, 0), - aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), - aten.avg_pool3d.default: elementwise_flop_counter(1, 0), - aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1), - aten.max_pool1d.default: elementwise_flop_counter(1, 0), - aten.max_pool2d.default: elementwise_flop_counter(1, 0), - aten.max_pool3d.default: elementwise_flop_counter(1, 0), - aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0), - aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0), - aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1), - aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0), - aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1), - aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0), - aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1), - aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0), - aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1), - aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1), - aten.embedding.default: elementwise_flop_counter(1, 0), -} + aten.avg_pool1d.default: elementwise_flop_counter(1, 0), + aten.avg_pool2d.default: elementwise_flop_counter(1, 0), + aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), + aten.avg_pool3d.default: elementwise_flop_counter(1, 0), + aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1), + aten.max_pool1d.default: elementwise_flop_counter(1, 0), + aten.max_pool2d.default: elementwise_flop_counter(1, 0), + aten.max_pool3d.default: elementwise_flop_counter(1, 0), + aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1), + aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0), + aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1), + aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0), + aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1), + aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0), + aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1), + aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1), + aten.embedding.default: elementwise_flop_counter(1, 0), + } -elementwise_flop_aten = [ + elementwise_flop_aten = [ # basic op - aten.add.Tensor, - aten.add_.Tensor, - aten.div.Tensor, - aten.div_.Tensor, - aten.div.Scalar, - aten.div_.Scalar, - aten.mul.Tensor, - aten.mul.Scalar, - aten.mul_.Tensor, - aten.neg.default, - aten.pow.Tensor_Scalar, - aten.rsub.Scalar, - aten.sum.default, - aten.sum.dim_IntList, - aten.mean.dim, + aten.add.Tensor, + aten.add_.Tensor, + aten.div.Tensor, + aten.div_.Tensor, + aten.div.Scalar, + aten.div_.Scalar, + aten.mul.Tensor, + aten.mul.Scalar, + aten.mul_.Tensor, + aten.neg.default, + aten.pow.Tensor_Scalar, + aten.rsub.Scalar, + aten.sum.default, + aten.sum.dim_IntList, + aten.mean.dim, # activation op - aten.hardswish.default, - aten.hardswish_.default, - aten.hardswish_backward.default, - aten.hardtanh.default, - aten.hardtanh_.default, - aten.hardtanh_backward.default, - aten.hardsigmoid_backward.default, - aten.hardsigmoid.default, - aten.gelu.default, - aten.gelu_backward.default, - aten.silu.default, - aten.silu_.default, - aten.silu_backward.default, - aten.sigmoid.default, - aten.sigmoid_backward.default, - aten._softmax.default, - aten._softmax_backward_data.default, - aten.relu_.default, - aten.relu.default, - aten.tanh.default, - aten.tanh_backward.default, - aten.threshold_backward.default, + aten.hardswish.default, + aten.hardswish_.default, + aten.hardswish_backward.default, + aten.hardtanh.default, + aten.hardtanh_.default, + aten.hardtanh_backward.default, + aten.hardsigmoid_backward.default, + aten.hardsigmoid.default, + aten.gelu.default, + aten.gelu_backward.default, + aten.silu.default, + aten.silu_.default, + aten.silu_backward.default, + aten.sigmoid.default, + aten.sigmoid_backward.default, + aten._softmax.default, + aten._softmax_backward_data.default, + aten.relu_.default, + aten.relu.default, + aten.tanh.default, + aten.tanh_backward.default, + aten.threshold_backward.default, # dropout - aten.native_dropout.default, - aten.native_dropout_backward.default, -] + aten.native_dropout.default, + aten.native_dropout_backward.default, + ] + for op in elementwise_flop_aten: + flop_mapping[op] = elementwise_flop_counter(1, 0) -for op in elementwise_flop_aten: - flop_mapping[op] = elementwise_flop_counter(1, 0) + # TODO: this will be removed in future + zero_flop_aten = [ + aten.as_strided.default, + aten.as_strided_.default, + aten.bernoulli_.float, + aten.cat.default, + aten.clone.default, + aten.copy_.default, + aten.detach.default, + aten.expand.default, + aten.empty_like.default, + aten.new_empty.default, + aten.new_empty_strided.default, + aten.ones_like.default, + aten._reshape_alias.default, + aten.select.int, + aten.select_backward.default, + aten.squeeze.dim, + aten.slice.Tensor, + aten.slice_backward.default, + aten.split.Tensor, + aten.permute.default, + aten.t.default, + aten.transpose.int, + aten._to_copy.default, + aten.unsqueeze.default, + aten.unbind.int, + aten._unsafe_view.default, + aten.view.default, + aten.where.self, + aten.zero_.default, + aten.zeros_like.default, + ] -# TODO: this will be removed in future -zero_flop_aten = [ - aten.as_strided.default, - aten.as_strided_.default, - aten.bernoulli_.float, - aten.cat.default, - aten.clone.default, - aten.copy_.default, - aten.detach.default, - aten.expand.default, - aten.empty_like.default, - aten.new_empty.default, - aten.new_empty_strided.default, - aten.ones_like.default, - aten._reshape_alias.default, - aten.select.int, - aten.select_backward.default, - aten.squeeze.dim, - aten.slice.Tensor, - aten.slice_backward.default, - aten.split.Tensor, - aten.permute.default, - aten.t.default, - aten.transpose.int, - aten._to_copy.default, - aten.unsqueeze.default, - aten.unbind.int, - aten._unsafe_view.default, - aten.view.default, - aten.where.self, - aten.zero_.default, - aten.zeros_like.default, -] + for op in zero_flop_aten: + flop_mapping[op] = zero_flop_jit -for op in zero_flop_aten: - flop_mapping[op] = zero_flop_jit +else: + flop_mapping = {} + elementwise_flop_aten = {} + zero_flop_aten = {} diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 2fa5c41c0..c87cd4321 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map from .._compatibility import compatibility from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase -from .memory import activation_size, parameter_size +from .memory_utils import activation_size, parameter_size from .opcount import flop_mapping from .tensor import MetaTensor @@ -232,12 +232,12 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G def pack(x): global cache, do_not_cache - if isinstance(x, FlopTensor) and not x._tensor.uuid in cache: + if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache: tensor = x._tensor.detach() - tensor.uuid = x._tensor.uuid + tensor.data_ptr = x._tensor.data_ptr x._node.meta['saved_tensor'] += [tensor] if not do_not_cache: - cache.add(x._tensor.uuid) + cache.add(x._tensor.data_ptr()) return x def unpack(x): @@ -270,7 +270,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G def extract_tensor(x: Any): if isinstance(x, MetaTensor): tensor = x._tensor.detach() - tensor.uuid = x._tensor.uuid + tensor.data_ptr = x._tensor.data_ptr return tensor if not isinstance(x, torch.finfo): return x @@ -286,13 +286,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G @compatibility(is_backward_compatible=True) def profile_function(target: 'Target', device: str = 'meta') -> Callable: """ - Wrap a `call_function` node or `torch.nn.functional` in order to + Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost and FLOPs of the execution. - + Warnings: You may only use tensors with `device=meta` for this wrapped function. Only original `torch.nn.functional` are available. - + Examples: >>> input = torch.rand(100, 100, 100, 100, device='meta') >>> func = torch.nn.functional.relu @@ -328,6 +328,8 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: out, meta = _profile_concrete(func, *args, **kwargs) if inplace: kwargs['inplace'] = True + meta.bwd_mem_tmp = 0 + meta.bwd_mem_out = 0 do_not_cache = False meta.bwd_mem_out -= param_size @@ -342,7 +344,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: def profile_method(target: 'Target', device: str = 'meta') -> Callable: """ Wrap a `call_method` node - record the memory cost and FLOPs of the execution. + record the memory cost and FLOPs of the execution. """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: @@ -360,13 +362,13 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable: @compatibility(is_backward_compatible=True) def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: """ - Wrap a `call_module` node or `torch.nn` in order to + Wrap a `call_module` node or `torch.nn` in order to record the memory cost and FLOPs of the execution. - + Warnings: You may only use tensors with `device=meta` for this wrapped function. Only original `torch.nn` are available. - + Example: >>> input = torch.rand(4, 3, 224, 224, device='meta') >>> mod = torch.nn.Conv2d(3, 128, 3) @@ -394,6 +396,8 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: out, meta = _profile_concrete(func, *args, **kwargs) if inplace: module.inplace = True + meta.bwd_mem_tmp = 0 + meta.bwd_mem_out = 0 do_not_cache = False # grad for param will not be counted diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/shard_utils.py similarity index 53% rename from colossalai/fx/profiler/memory.py rename to colossalai/fx/profiler/shard_utils.py index 2e8b5d51b..34feefb43 100644 --- a/colossalai/fx/profiler/memory.py +++ b/colossalai/fx/profiler/shard_utils.py @@ -1,58 +1,18 @@ -from typing import Dict, List, Tuple, Union - import torch -from torch.fx import GraphModule, Node +from torch.fx import Node from .._compatibility import compatibility, is_compatible_with_meta +from .memory_utils import activation_size if is_compatible_with_meta(): from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS -__all__ = [ - 'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out" -] - - -@compatibility(is_backward_compatible=True) -def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: - """Calculate activation size of a node. - - Args: - activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional` - - Returns: - int: The activation size - """ - act_size = 0 - if isinstance(out, torch.Tensor): - act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size() - elif isinstance(out, dict): - value_list = [v for _, v in out.items()] - act_size += activation_size(value_list) - elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set): - for element in out: - act_size += activation_size(element) - return act_size - - -@compatibility(is_backward_compatible=True) -def parameter_size(mod: torch.nn.Module) -> int: - """Calculate parameter size of a node. - - Args: - mod (torch.nn.Module): The target `torch.nn.Module` - - Returns: - int: The parameter size - """ - param_size = 0 - for param in mod.parameters(): - param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size() - return param_size +__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"] +@compatibility(is_backward_compatible=False) def calculate_fwd_in(n: Node) -> int: - """A helper function to calculate `fwd_in` + """A helper function to calculate `fwd_in` (with sharding spec) Args: n (Node): a node from the graph @@ -60,11 +20,13 @@ def calculate_fwd_in(n: Node) -> int: Returns: fwd_in (int): the result of `fwd_in` """ + # TODO(super-dainiu): should divide the memory by sharding spec return activation_size(n.meta["fwd_in"]) +@compatibility(is_backward_compatible=False) def calculate_fwd_tmp(n: Node) -> int: - """A helper function to calculate `fwd_tmp` + """A helper function to calculate `fwd_tmp` (with sharding spec) Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy. Args: @@ -74,6 +36,7 @@ def calculate_fwd_tmp(n: Node) -> int: fwd_tmp (int): the result of `fwd_tmp` """ + # TODO(super-dainiu): should divide the memory by sharding spec def is_relu_like_node(n: Node) -> bool: """Check if a node is a ReLU-like node. ReLU-like nodes have the following properties: @@ -107,8 +70,9 @@ def calculate_fwd_tmp(n: Node) -> int: return 0 +@compatibility(is_backward_compatible=False) def calculate_fwd_out(n: Node) -> int: - """A helper function to calculate `fwd_out` + """A helper function to calculate `fwd_out` (with sharding spec) Args: n (Node): a node from the graph @@ -117,33 +81,34 @@ def calculate_fwd_out(n: Node) -> int: fwd_out (int): the result of `fwd_out` """ + # TODO(super-dainiu): should divide the memory by sharding spec def intersect(a, b): return {k: a[k] for k in a if k in b} fwd_in = dict() for u in n.users: - fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}) - fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')} + fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)}) + fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)} return activation_size(intersect(fwd_in, fwd_out)) -def is_inplace(n: Node): - """Get the inplace argument from torch.fx.Node - +def calculate_fwd_time(n: Node) -> float: + """A helper function to calculate `fwd_time` (with sharding spec) Args: - node (Node): torch.fx.Node - + n (Node): a node from the graph Returns: - bool: indicates whether this op is inplace + fwd_time (float): the result of `fwd_time` """ - inplace = False - if n.op == "call_function": - inplace = n.kwargs.get("inplace", False) - if is_compatible_with_meta(): - from .constants import ALIAS_ATEN - if n.target in ALIAS_ATEN: - inplace = True - elif n.op == "call_module": - inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs + return n.meta["fwd_time"] - return inplace + +def calculate_bwd_time(n: Node) -> float: + """A helper function to calculate `bwd_time` (with sharding spec) + Args: + n (Node): a node from the graph + Returns: + bwd_time (float): the result of `bwd_time` + """ + # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs + return n.meta["bwd_time"] diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py index 3be3dd65c..43165305f 100644 --- a/colossalai/fx/profiler/tensor.py +++ b/colossalai/fx/profiler/tensor.py @@ -12,10 +12,11 @@ from .constants import ALIAS_ATEN __all__ = ['MetaTensor'] -def set_uuid(x): +def set_data_ptr(x): if isinstance(x, torch.Tensor): - if not hasattr(x, 'uuid'): - setattr(x, 'uuid', uuid.uuid4()) + if not x.data_ptr(): + data_ptr = uuid.uuid4() + x.data_ptr = lambda: data_ptr @compatibility(is_backward_compatible=False) @@ -53,7 +54,7 @@ class MetaTensor(torch.Tensor): if not r._tensor.is_meta: r._tensor = r._tensor.to(torch.device('meta')) # only tensor not on `meta` should be copied to `meta` - set_uuid(r._tensor) + set_data_ptr(r._tensor) return r def __repr__(self): @@ -88,7 +89,7 @@ class MetaTensor(torch.Tensor): # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy # of the input if func in ALIAS_ATEN: - setattr(out, 'uuid', args[0].uuid) + out.data_ptr = args[0].data_ptr # Now, we want to continue propagating this tensor, so we rewrap Tensors in # our custom tensor subclass @@ -127,3 +128,13 @@ class MetaTensor(torch.Tensor): if device is not None: result = MetaTensor(result, fake_device=device) return result + + def cpu(self, *args, **kwargs): + if self.device.type == 'cpu': + return self.to(*args, **kwargs) + return self.to(*args, device='cpu', **kwargs) + + def cuda(self, *args, **kwargs): + if self.device.type == 'cuda': + return self.to(*args, **kwargs) + return self.to(*args, device='cuda', **kwargs) diff --git a/colossalai/fx/tracer/__init__.py b/colossalai/fx/tracer/__init__.py index 327e1510e..590555ce3 100644 --- a/colossalai/fx/tracer/__init__.py +++ b/colossalai/fx/tracer/__init__.py @@ -1,2 +1,5 @@ -from .tracer import ColoTracer -from ._meta_trace import meta_trace +from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem + +from ._meta_trace import meta_trace +from ._symbolic_trace import symbolic_trace +from .tracer import ColoTracer diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py index a7f7c8159..1c5abb81d 100644 --- a/colossalai/fx/tracer/_meta_trace.py +++ b/colossalai/fx/tracer/_meta_trace.py @@ -1,7 +1,5 @@ -from colossalai.fx.profiler.memory import activation_size import torch -from torch.fx import Node, Graph -from torch.fx.graph import _Namespace +from torch.fx import Graph, Node from torch.utils._pytree import tree_map diff --git a/colossalai/fx/tracer/_symbolic_trace.py b/colossalai/fx/tracer/_symbolic_trace.py new file mode 100644 index 000000000..bff2f6a10 --- /dev/null +++ b/colossalai/fx/tracer/_symbolic_trace.py @@ -0,0 +1,54 @@ +from typing import Any, Callable, Dict, Optional, Union + +import torch + +from colossalai.fx import ColoGraphModule +from colossalai.fx._compatibility import compatibility + +from .tracer import ColoTracer + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + meta_args: Optional[Dict[str, Any]] = None, +) -> ColoGraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule`` + constructed by recording operations seen while tracing through ``root``. + + With ``meta_args``, we can trace the model that are untraceable subject to control flow. If specified using + ``meta_args`` only, the tracing can be done ahead of time. + + Note that ``meta_args`` are kwargs, which contains the key of the argument's names and the value of the + argument's values. + + Uses: + >>> model = ... + + # if this works + >>> gm = symbolic_trace(model, concrete_args=concrete_args) + + # else try this + >>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')}) + + Args: + root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted + into a Graph representation. + concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing. + meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``. + Defaults to None. + + Returns: + ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``. + + Warnings: + This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team. + + """ + graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return ColoGraphModule(root, graph, name) diff --git a/colossalai/fx/tracer/bias_addition_patch/__init__.py b/colossalai/fx/tracer/bias_addition_patch/__init__.py new file mode 100644 index 000000000..e724d6a22 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/__init__.py @@ -0,0 +1,2 @@ +from .patched_bias_addition_function import * +from .patched_bias_addition_module import * diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py new file mode 100644 index 000000000..071bde4a5 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py @@ -0,0 +1,4 @@ +from .addbmm import Addbmm +from .addmm import Addmm +from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict +from .linear import Linear diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py new file mode 100644 index 000000000..859a19bf6 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py @@ -0,0 +1,75 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function, bias_addition_method +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_method.register(torch.Tensor.addbmm) +@bias_addition_function.register(torch.addbmm) +class Addbmm(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + kwargs = {} + if 'beta' in self.kwargs: + kwargs['beta'] = self.kwargs['beta'] + if 'alpha' in self.kwargs: + kwargs['alpha'] = self.kwargs['alpha'] + return kwargs + + def create_non_bias_func_proxy(self, input_proxy, other_proxy): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + assert self.substitute_func == torch.bmm + node_kind = 'call_function' + node_target = self.substitute_func + + node_args = (input_proxy, other_proxy) + # torch.bmm does not have any kwargs + node_kwargs = {} + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def insert_sum_node(self, input_proxy, sum_dims=0): + ''' + This method is used to sum the input_proxy through the sum_dims. + ''' + node_kind = 'call_function' + node_target = torch.sum + node_args = (input_proxy, sum_dims) + node_kwargs = {} + sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return sum_proxy + + def generate(self): + # The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2)) + + # doing the non-bias computation(temp_0 = torch.bmm(b1, b2)) + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2]) + + # doing sum on the batch dimension(temp_1 = torch.sum(temp_0, 0)) + sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy) + kwargs = self.extract_kwargs_from_origin_func() + + if 'beta' in kwargs: + beta = kwargs['beta'] + # doing the multiplication with beta if it exists(temp_2 = beta * input) + beta_proxy = self.create_mul_node(self.args[0], beta) + else: + beta_proxy = self.args[0] + + if 'alpha' in kwargs: + alpha = kwargs['alpha'] + # doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1) + alpha_proxy = self.create_mul_node(alpha, sum_proxy) + else: + alpha_proxy = sum_proxy + + # doing the addition(temp_4 = temp_2 + temp_3) + bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py new file mode 100644 index 000000000..fe7d8d07a --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py @@ -0,0 +1,60 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function, bias_addition_method +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_method.register(torch.Tensor.addmm) +@bias_addition_function.register(torch.addmm) +class Addmm(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + kwargs = {} + if 'beta' in self.kwargs: + kwargs['beta'] = self.kwargs['beta'] + if 'alpha' in self.kwargs: + kwargs['alpha'] = self.kwargs['alpha'] + return kwargs + + def transpose_other_operand_for_linear(self, other_proxy): + ''' + This method is used to transpose the other operand for linear function. + For example: + input = torch.rand(3, 4) + m1 = torch.rand(3, 5) + m2 = torch.rand(5, 4) + original_output = torch.addmm(input, m1, m2) + # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2 + # before we call the linear function. + new_output = torch.linear(m1, m2.transpose(0, 1)) + input + ''' + node_kind = 'call_function' + node_target = torch.transpose + node_args = (other_proxy, 0, 1) + node_kwargs = {} + transpose_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return transpose_proxy + + def generate(self): + transpose_proxy = self.transpose_other_operand_for_linear(self.args[2]) + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy) + kwargs = self.extract_kwargs_from_origin_func() + + if 'beta' in kwargs: + beta = kwargs['beta'] + beta_proxy = self.create_mul_node(self.args[0], beta) + else: + beta_proxy = self.args[0] + + if 'alpha' in kwargs: + alpha = kwargs['alpha'] + alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy) + else: + alpha_proxy = non_bias_linear_func_proxy + + bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py new file mode 100644 index 000000000..8a3786332 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py @@ -0,0 +1,115 @@ +import operator +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F + + +class BiasAdditionFunc(ABC): + """ + This class is used to construct the restructure computation graph for + call_func node with bias addition inside. + """ + + def __init__(self, tracer, target, args, kwargs, substitute_func): + self.tracer = tracer + self.target = target + self.args = args + self.kwargs = kwargs + self.substitute_func = substitute_func + + @abstractmethod + def extract_kwargs_from_origin_func(self): + """ + This method is used to extract the kwargs for further graph transform. + + For example: + The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2) + The kwargs for addmm function is {beta=1, alpha=1, output=None}, then we need + to insert two more operator.mul nodes for the computation graph to compute the + final result. + """ + pass + + @abstractmethod + def generate(self): + """ + This method is used to construct the whole restructure computation graph for call_func node with bias + addition inside. + + A whole restructure computation graph will contain a weight node, a bias node, a non-bias addition computation node, + a bias reshape node if needed and a bias addition node. + + Use torch.addmm as an example: + The origin node is: + %addmm: call_func[target=torch.addmm](args = (%input_1, m1, m2), kwargs = {beta=1, alpha=1}) + Restructured graph is: + %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {}) + %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {}) + %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {}) + %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) + %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) + """ + pass + + def create_mul_node(self, input_proxy, coefficent): + """ + This method is used to create a coefficent node for the numerical correctness. + The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2) + Therefore, we need to use this method insert two more operator.mul nodes for + the computation graph to compute the final result. + """ + node_kind = 'call_function' + node_target = operator.mul + node_args = ( + input_proxy, + coefficent, + ) + node_kwargs = {} + mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return mul_proxy + + +class LinearBasedBiasFunc(BiasAdditionFunc): + """ + This class is used to construct the restructure computation graph for + call_func node based on F.linear. + """ + + def create_non_bias_func_proxy(self, input_proxy, other_proxy): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + assert self.substitute_func == torch.nn.functional.linear + node_kind = 'call_function' + node_target = self.substitute_func + + node_args = (input_proxy, other_proxy) + # non-bias linear does not have any kwargs + node_kwargs = {} + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): + """ + This method is used to create the bias_addition_proxy, the node created by this proxy will + compute the sum of non_bias_func result and bias with some reshape operation if needed. + """ + bias_add_node_kind = 'call_function' + bias_add_node_target = operator.add + bias_add_args = (non_bias_func_proxy, bias_proxy) + bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) + return bias_add_proxy + + +func_to_func_dict = { + torch.addmm: F.linear, + torch.addbmm: torch.bmm, + F.linear: F.linear, +} + +method_to_func_dict = { + torch.Tensor.addmm: F.linear, + torch.Tensor.addbmm: torch.bmm, +} diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py new file mode 100644 index 000000000..e11ec0a36 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py @@ -0,0 +1,25 @@ +import operator + +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_function +from .bias_addition_function import LinearBasedBiasFunc + + +@bias_addition_function.register(F.linear) +class Linear(LinearBasedBiasFunc): + + def extract_kwargs_from_origin_func(self): + assert 'bias' in self.kwargs + kwargs = {} + if 'bias' in self.kwargs: + kwargs['bias'] = self.kwargs['bias'] + return kwargs + + def generate(self): + non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1]) + kwargs = self.extract_kwargs_from_origin_func() + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias']) + + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py new file mode 100644 index 000000000..f3823bb3e --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py @@ -0,0 +1,3 @@ +from .bias_addition_module import * +from .conv import * +from .linear import * diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py new file mode 100644 index 000000000..85f1553e3 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py @@ -0,0 +1,111 @@ +import operator +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F + + +class BiasAdditionModule(ABC): + """ + This class is used to construct the restructure computation graph for + call_module node with bias addition inside. + """ + + def __init__(self, tracer, target, args, kwargs, substitute_func): + self.tracer = tracer + self.target = target + self.args = args + self.kwargs = kwargs + self.substitute_func = substitute_func + self.weight_proxy = self._create_weight_proxy() + self.bias_proxy = self._create_bias_proxy() + + def _create_weight_proxy(self): + """ + Create weight proxy, the node created by this proxy contains module weight. + + Note: this function will be invoked during module initializing, + you should never call this function. + """ + weight_node_kind = 'get_attr' + weight_node_target = self.target + '.weight' + weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {}) + return weight_proxy + + def _create_bias_proxy(self): + """ + Create bias proxy, the node created by this proxy contains module bias. + + Note: this function will be invoked during module initializing, + you should never call this function. + """ + bias_node_kind = 'get_attr' + bias_node_target = self.target + '.bias' + bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {}) + return bias_proxy + + @abstractmethod + def extract_kwargs_from_mod(self): + """ + This method is used to extract the kwargs for non-bias computation. + + For example: + The kwargs for conv2d module is {} because the attributes like 'padding' or 'groups' are + considered during module initilizing. However, we need to consider those attributes as kwargs + in F.conv2d. + """ + pass + + def create_non_bias_func_proxy(self, input_proxy=None): + """ + This method is used to create the non_bias_func proxy, the node created by this proxy will + compute the main computation, such as convolution, with bias option banned. + """ + node_kind = 'call_function' + node_target = self.substitute_func + if input_proxy is None: + input_proxy = self.args[0] + node_args = (input_proxy, self.weight_proxy) + node_kwargs = self.extract_kwargs_from_mod() + non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs) + return non_bias_func_proxy + + def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): + """ + This method is used to create the bias_addition_proxy, the node created by this proxy will + compute the sum of non_bias_func result and bias with some reshape operation if needed. + """ + bias_add_node_kind = 'call_function' + bias_add_node_target = operator.add + bias_add_args = (non_bias_func_proxy, bias_proxy) + bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) + return bias_add_proxy + + @abstractmethod + def generate(self): + """ + This method is used to construct the whole restructure computation graph for call_module node with bias + addition inside. + + A whole restructure computation graph will contain a weight node, a bias node, a non-bias addition computation node, + a bias reshape node if needed and a bias addition node. + + Use Conv2d module as an example: + The origin node is: + %conv: call_module[target=conv](args = (%x,), kwargs = {}) + Restructured graph is: + %conv_weight : [#users=1] = get_attr[target=conv.weight] + %conv_bias : [#users=1] = get_attr[target=conv.bias] + %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + """ + pass + + +module_to_func_dict = { + torch.nn.Linear: F.linear, + torch.nn.Conv1d: F.conv1d, + torch.nn.Conv2d: F.conv2d, + torch.nn.Conv3d: F.conv3d, +} diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py new file mode 100644 index 000000000..4b6c82a74 --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py @@ -0,0 +1,56 @@ +import torch +import torch.nn.functional as F +from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple + +from ...registry import bias_addition_module +from .bias_addition_module import BiasAdditionModule + + +@bias_addition_module.register(torch.nn.Conv1d) +@bias_addition_module.register(torch.nn.Conv2d) +@bias_addition_module.register(torch.nn.Conv3d) +class BiasAdditionConv(BiasAdditionModule): + + def extract_kwargs_from_mod(self): + root = self.tracer.root + conv_module = root.get_submodule(self.target) + kwarg_attributes = ['groups', 'dilation', 'stride'] + non_bias_kwargs = {} + for attr_name in kwarg_attributes: + if hasattr(conv_module, attr_name): + non_bias_kwargs[attr_name] = getattr(conv_module, attr_name) + if conv_module.padding_mode != "zeros": + #TODO: non zeros mode requires some extra processing for input + conv_type = type(conv_module) + if conv_type == "torch.nn.Conv1d": + padding_element = _single(0) + elif conv_type == "torch.nn.Conv2d": + padding_element = _pair(0) + elif conv_type == "torch.nn.Conv3d": + padding_element = _triple(0) + non_bias_kwargs['padding'] = padding_element + else: + non_bias_kwargs['padding'] = getattr(conv_module, 'padding') + + return non_bias_kwargs + + def create_bias_reshape_proxy(self, dimensions): + """ + This method is used to reshape the bias node in order to make bias and + output of non-bias convolution broadcastable. + """ + bias_shape = [1] * (dimensions - 1) + bias_shape[0] = -1 + bias_reshape_node_kind = 'call_method' + bias_reshape_node_target = 'view' + bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape)) + bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target, + bias_reshape_node_args, {}) + return bias_reshape_proxy + + def generate(self): + non_bias_conv_func_proxy = self.create_non_bias_func_proxy() + output_dims = non_bias_conv_func_proxy.meta_data.dim() + bias_reshape_proxy = self.create_bias_reshape_proxy(output_dims) + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_conv_func_proxy, bias_reshape_proxy) + return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py new file mode 100644 index 000000000..f6f7b6dda --- /dev/null +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + +from ...registry import bias_addition_module +from .bias_addition_module import BiasAdditionModule + + +@bias_addition_module.register(torch.nn.Linear) +class BiasAdditionLinear(BiasAdditionModule): + + def extract_kwargs_from_mod(self): + return {} + + def generate(self): + non_bias_linear_func_proxy = self.create_non_bias_func_proxy() + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, self.bias_proxy) + return bias_addition_proxy diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py new file mode 100644 index 000000000..6fee5f5d0 --- /dev/null +++ b/colossalai/fx/tracer/experimental.py @@ -0,0 +1,642 @@ +import enum +import functools +import operator +import inspect +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +from torch.fx import Graph, Node, Proxy, Tracer +from torch.utils._pytree import tree_map + +from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta +from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list +from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict +from colossalai.fx.tracer.registry import ( + bias_addition_function, + bias_addition_method, + bias_addition_module, + meta_patched_function, + meta_patched_module, +) + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +Target = Union[Callable[..., Any], str] +Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + 'Node',]] +_CScriptMethod = ['add', 'mul', 'sub', 'div'] +_TorchNewMethod = [ + "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor", + "finfo" +] +_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"] + + +def _truncate_suffix(s: str): + import re + return re.sub(r'_\d+$', '', s) + + +def default_device(): + return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + + +@compatibility(is_backward_compatible=False) +class ColoProxy(Proxy): + + def __init__(self, *args, data=None, **kwargs): + super().__init__(*args, **kwargs) + self._meta_data = data + + @property + def meta_data(self): + return self._meta_data + + @meta_data.setter + def meta_data(self, args): + wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x + self._meta_data = tree_map(wrap_fn, args) + + @classmethod + def __torch_function__(cls, orig_method, types, args=(), kwargs=None): + proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs)) + unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p + kwargs = {} if kwargs is None else kwargs + if proxy.meta_data is None: + proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + return proxy + + @classmethod + def from_torch_proxy(cls, proxy: Proxy): + return cls(proxy.node, proxy.tracer) + + def __repr__(self): + return f"ColoProxy({self.node.name}, meta_data={self.meta_data})" + + def __len__(self): + return len(self.meta_data) + + def __int__(self): + return int(self.meta_data) + + def __index__(self): + try: + return int(self.meta_data) + except: + return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__() + + def __float__(self): + return float(self.meta_data) + + def __bool__(self): + return self.meta_data + + def __getattr__(self, k): + return ColoAttribute(self, k, getattr(self._meta_data, k, None)) + + def __setitem__(self, key, value): + proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy.meta_data = self._meta_data + return proxy + + def __contains__(self, key): + if self.node.op == "placeholder": + # this is used to handle like + # if x in kwargs + # we don't handle this case for now + return False + return super().__contains__(key) + + def __isinstancecheck__(self, type): + return isinstance(self.meta_data, type) + + @property + def shape(self): + return self.meta_data.shape + + @property + def ndim(self): + return self.meta_data.ndim + + @property + def device(self): + proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {}) + proxy.meta_data = self.meta_data.device + return proxy + + @property + def dtype(self): + proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {}) + proxy.meta_data = self.meta_data.dtype + return proxy + + def to(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs}) + + def cpu(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs}) + + def cuda(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs}) + + +@compatibility(is_backward_compatible=False) +class ColoAttribute(ColoProxy): + + def __init__(self, root, attr: str, data=None): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._meta_data = data + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + + def __repr__(self): + return f"ColoAttribute({self.node.name}, attr={self.attr})" + + +@compatibility(is_backward_compatible=False) +class ColoTracer(Tracer): + + def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): + super().__init__(*args, **kwargs) + self._disable_module_getattr = False + self.proxy_buffer_attributes = True + + # whether the tracer will record the usage of torch.utils.checkpoint + self.trace_act_ckpt = trace_act_ckpt + # whether the current tracing occurs within the activation checkpoint functions + self.inside_torch_checkpoint_func = False + self.act_ckpt_region_count = 0 + + def proxy(self, node: Node) -> 'ColoProxy': + return ColoProxy(node, self) + + def create_proxy(self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], 'Proxy'] = None): + + proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p + if kind == 'placeholder': + proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( + _truncate_suffix(target), None) + elif kind == 'get_attr': + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + proxy.meta_data = attr_itr + finally: + self._disable_module_getattr = False + elif kind == 'call_function': + proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + elif kind == 'call_method': + self._disable_module_getattr = True + try: + if target == '__call__': + proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) + else: + if target not in _TensorPropertyMethod: + proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), + **tree_map(unwrap_fn, kwargs)) + finally: + self._disable_module_getattr = False + elif kind == 'call_module': + mod = self.root.get_submodule(target) + self._disable_module_getattr = True + try: + proxy.meta_data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + finally: + self._disable_module_getattr = False + return proxy + + def create_node(self, *args, **kwargs) -> Node: + node = super().create_node(*args, **kwargs) + + if self.inside_torch_checkpoint_func: + # annotate the activation checkpoint module + node.meta['activation_checkpoint'] = self.act_ckpt_region_count + return node + + def trace(self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: + + if meta_args is None: + meta_args = {} + + if concrete_args is None: + concrete_args = {} + + # check concrete and meta args have valid names + sig = inspect.signature(root.forward) + sig_names = set(sig.parameters.keys()) + meta_arg_names = set(meta_args.keys()) + + # update concrete args with default values + non_meta_arg_names = sig_names - meta_arg_names + for k, v in sig.parameters.items(): + if k in non_meta_arg_names and \ + k not in concrete_args and \ + v.default is not inspect.Parameter.empty: + concrete_args[k] = v.default + + # get non concrete arg names + concrete_arg_names = set(concrete_args.keys()) + non_concrete_arg_names = sig_names - concrete_arg_names + + def _check_arg_name_valid(names): + success, element = is_element_in_list(names, sig_names) + if not success: + raise KeyError( + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + + _check_arg_name_valid(meta_arg_names) + _check_arg_name_valid(concrete_arg_names) + + self.concrete_args = concrete_args + self.meta_args = meta_args + + with _TorchTensorOverride(self), self.trace_activation_checkpoint(enabled=self.trace_act_ckpt): + self.graph = super().trace(root, concrete_args=concrete_args) + self.graph.lint() + return self.graph + + + @contextmanager + def trace_activation_checkpoint(self, enabled: bool): + if enabled: + orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction + + class PatchedCheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + # signal that the current tracing occurs within activaton checkpoint part + self.inside_torch_checkpoint_func = True + out = run_function(*args) + self.inside_torch_checkpoint_func = False + self.act_ckpt_region_count += 1 + return out + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + raise NotImplementedError( + "We do not implement the backward pass as we only trace the forward pass.") + + # override the checkpoint function + torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction + yield + + if enabled: + # recover the checkpoint function upon exit + torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func + + + def _post_check(self, non_concrete_arg_names: Set[str]): + # This is necessary because concrete args are added as input to the traced module since + # https://github.com/pytorch/pytorch/pull/55888. + for node in self.graph.nodes: + if node.op == "placeholder": + # Removing default values for inputs as the forward pass will fail with them. + if node.target in non_concrete_arg_names: + node.args = () + # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor]. + # It cannot infer on the attributes and methods the input should have, and fails. + node.type = torch.Tensor + # It is a concrete arg so it is not used and should be removed. + else: + if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): + # Newer versions of torch.fx emit an assert statement + # for concrete arguments; delete those before we delete + # the concrete arg. + to_delete = [] + for user in node.users: + if user.target == torch.fx._symbolic_trace._assert_is_none: + to_delete.append(user) + for user in to_delete: + self.graph.erase_node(user) + + self.graph.erase_node(node) + + # TODO: solves GraphModule creation. + # Without this, return type annotation "Tuple" is causing code execution failure. + if node.op == "output": + node.type = None + self.graph.lint() + + def _module_getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, "_disable_module_getattr", False): + return attr_val + + def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: + kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else + lambda node: ColoProxy(self, node, n, attr_val)) + val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), + parameter_proxy_cache) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + return attr_val + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + meta_args: Optional[Dict[str, Any]] = None, +) -> ColoGraphModule: + if is_compatible_with_meta(): + if meta_args is not None: + root.to(default_device()) + wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x + graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)) + root.cpu() + else: + graph = Tracer().trace(root, concrete_args=concrete_args) + else: + from .tracer import ColoTracer as OrigColoTracer + graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return ColoGraphModule(root, graph, name) + + +@compatibility(is_backward_compatible=False) +class _TorchTensorOverride(object): + + def __init__(self, tracer: Tracer): + self.overrides = {} + self.tracer = tracer + + def __enter__(self): + + def wrap_tensor_method(target): + + @functools.wraps(target) + def wrapper(*args, **kwargs): + is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( + isinstance(p, ColoProxy) for p in kwargs.values()) + if is_proxy: + # if the arg is a proxy, then need to record this function called on this proxy + # e.g. torch.ones(size) where size is an input proxy + self.tracer._disable_module_getattr = True + try: + proxy = self.tracer.create_proxy('call_function', target, args, kwargs) + finally: + self.tracer._disable_module_getattr = False + return proxy + else: + return target(*args, **kwargs) + + return wrapper, target + + self.overrides = { + target: wrap_tensor_method(getattr(torch, target)) + for target in _TorchNewMethod + if callable(getattr(torch, target)) + } + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, wrapper) + + def __exit__(self, exc_type, exc_val, exc_tb): + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, orig) + + +def meta_prop_pass(gm: ColoGraphModule, + root: torch.nn.Module, + meta_args: Optional[Dict[str, Any]] = None, + concrete_args: Optional[Dict[str, torch.Tensor]] = None): + + if meta_args is None: + meta_args = {} + + if concrete_args is None: + concrete_args = {} + + # check concrete and meta args have valid names + sig = inspect.signature(root.forward) + sig_names = set(sig.parameters.keys()) + meta_arg_names = set(meta_args.keys()) + + # update concrete args with default values + non_meta_arg_names = sig_names - meta_arg_names + for k, v in sig.parameters.items(): + if k in non_meta_arg_names and \ + k not in concrete_args and \ + v.default is not inspect.Parameter.empty: + concrete_args[k] = v.default + + for node in gm.graph.nodes: + node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args, + node.kwargs) + +def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs): + unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n + if kind == 'placeholder': + meta_out = meta_args[target] if target in meta_args else concrete_args.get( + _truncate_suffix(target), None) + elif kind == 'get_attr': + attr_itr = root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + meta_out = attr_itr + elif kind == 'call_function': + meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + elif kind == 'call_method': + if target == '__call__': + meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) + else: + if target not in _TensorPropertyMethod: + meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), + **tree_map(unwrap_fn, kwargs)) + elif kind == 'call_module': + mod = root.get_submodule(target) + meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + else: + meta_out = None + return meta_out + +def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs): + if kind == "placeholder" and target in meta_args and meta_args[target].is_meta: + meta_out = meta_args[target] + return meta_out + + if target in [getattr(torch, torch_func) for torch_func in _TorchNewMethod]: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if "device" in kwargs: + kwargs["device"] = "meta" + + try: + unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n + args_metas = tree_map(unwrap_fn, args) + kwargs_metas = tree_map(unwrap_fn, kwargs) + + if kind == "call_function": + # fetch patched function + if meta_patched_function.has(target): + meta_target = meta_patched_function.get(target) + elif meta_patched_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + meta_target = meta_patched_function.get(target.__name__) + else: + meta_target = target + + meta_out = meta_target(*args_metas, **kwargs_metas) + + if isinstance(meta_out, torch.Tensor): + meta_out = meta_out.to(device="meta") + elif kind == "call_method": + method = getattr(args_metas[0].__class__, target) + + # fetch patched method + if meta_patched_function.has(method): + meta_target = meta_patched_function.get(method) + else: + meta_target = method + + meta_out = meta_target(*args_metas, **kwargs_metas) + elif kind == "call_module": + mod = root.get_submodule(target) + mod_type = type(mod) + if meta_patched_module.has(mod_type): + meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas) + else: + meta_out = mod(*args_metas, **kwargs_metas) + elif kind == "get_attr": + attr_itr = root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.nn.parameter.Parameter): + meta_out = torch.nn.Parameter(attr_itr.to(device="meta")) + elif isinstance(attr_itr, torch.Tensor): + meta_out = attr_itr.to(device="meta") + else: + meta_out = attr_itr + else: + return None + + except Exception as e: + raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}") + + return meta_out + + +def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None): + result_graph = Graph() + value_remap = {} + unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n + + for orig_node in gm.graph.nodes: + assert hasattr(orig_node, "_meta_data") + kind = orig_node.op + target = orig_node.target + args = orig_node.args + kwargs = orig_node.kwargs + + args_metas = tree_map(unwrap_fn, args) + tracer = ColoTracer() + tracer.graph = Graph(tracer_cls=ColoTracer) + tracer.root = root_model + + def wrap_fn(n): + if isinstance(n, Node): + proxy = ColoProxy(n, tracer) + proxy.meta_data = n._meta_data + return proxy + return n + + args_proxy = tree_map(wrap_fn, args) + kwargs_proxy = tree_map(wrap_fn, kwargs) + + handle = None + if kind == "call_function": + if bias_addition_function.has(target): + if target == torch.nn.functional.linear: + if 'bias' in kwargs and kwargs['bias'] is not None: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) + else: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) + elif bias_addition_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) + + elif kind == "call_method": + method = getattr(args_metas[0].__class__, target) + if bias_addition_method.has(method): + function_to_substitute = method_to_func_dict[method] + handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) + + elif kind == "call_module": + # if not hasattr(self, "orig_forward"): + # raise AttributeError(f"{self} does not have an attribute called orig_forward") + mod = gm.get_submodule(target) + mod_type = type(mod) + if bias_addition_module.has(mod_type) and mod.bias is not None: + function_to_substitute = module_to_func_dict[mod_type] + handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) + + if handle is not None: + handle.generate() + for node_inserted in tracer.graph.nodes: + value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n]) + last_node = value_remap[node_inserted] + value_remap[orig_node] = last_node + else: + value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n]) + + del tracer + + gm.graph = result_graph + gm.recompile() + meta_prop_pass(gm, root_model, meta_args) + diff --git a/colossalai/fx/tracer/meta_patch/__init__.py b/colossalai/fx/tracer/meta_patch/__init__.py index 28b54b9bb..192aef7a4 100644 --- a/colossalai/fx/tracer/meta_patch/__init__.py +++ b/colossalai/fx/tracer/meta_patch/__init__.py @@ -1,3 +1,2 @@ -from .registry import * from .patched_function import * from .patched_module import * diff --git a/colossalai/fx/tracer/meta_patch/patched_function/__init__.py b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py index a40ca4c39..e00fdf6f5 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/__init__.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py @@ -1,7 +1,6 @@ from .activation_function import * from .arithmetic import * +from .convolution import * from .embedding import * from .normalization import * -from .python_ops import * from .torch_ops import * -from .convolution import * \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py index d710098c7..12c425148 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py @@ -1,7 +1,8 @@ import torch -from ..registry import meta_patched_function + +from ...registry import meta_patched_function @meta_patched_function.register(torch.nn.functional.relu) def torch_nn_func_relu(input, inplace=False): - return torch.empty(input.shape, device='meta') \ No newline at end of file + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py index 3e697de86..042b92c58 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py @@ -1,6 +1,6 @@ import torch -from ..registry import meta_patched_function +from ...registry import meta_patched_function @meta_patched_function.register(torch.matmul) @@ -57,16 +57,36 @@ def torch_bmm(input, mat2, *, out=None): return torch.empty(batch_size, n, p, device="meta") +@meta_patched_function.register(torch.nn.functional.linear) +def torch_linear(input, mat2, bias=None, *, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + output_shape = list(input.shape) + output_feature = list(mat2.shape)[0] + output_shape[-1] = output_feature + return torch.empty(*output_shape, device="meta") + + @meta_patched_function.register(torch.addbmm) @meta_patched_function.register(torch.Tensor.addbmm) def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): if out is not None: raise ValueError("Don't support in-place abs for MetaTensor analysis") - batch_size, n, m = mat1.shape + _, n, _ = mat1.shape _, _, p = mat2.shape return torch.empty(n, p, device="meta") +@meta_patched_function.register(torch.addmm) +@meta_patched_function.register(torch.Tensor.addmm) +def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + n, _ = mat1.shape + _, p = mat2.shape + return torch.empty(n, p, device="meta") + + @meta_patched_function.register(torch.var_mean) def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): assert out is None, 'saving to out is not supported yet' diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py index eb88f2451..8500e5c82 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py @@ -1,8 +1,10 @@ -import torch import collections -from itertools import repeat -from ..registry import meta_patched_function import math +from itertools import repeat + +import torch + +from ...registry import meta_patched_function def _ntuple(n, name="parse"): diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py index 42fb359b5..6d8d864ea 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py @@ -1,5 +1,6 @@ import torch -from ..registry import meta_patched_function + +from ...registry import meta_patched_function @meta_patched_function.register(torch.nn.functional.embedding) @@ -10,4 +11,4 @@ def torch_nn_functional_embedding(input, norm_type=2.0, scale_grad_by_freq=False, sparse=False): - return torch.empty(*input.shape, weight.shape[-1], device="meta") \ No newline at end of file + return torch.empty(*input.shape, weight.shape[-1], device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py index 80d034f9a..e9e7eda61 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py @@ -1,5 +1,6 @@ import torch -from ..registry import meta_patched_function + +from ...registry import meta_patched_function @meta_patched_function.register(torch.nn.functional.layer_norm) @@ -16,4 +17,4 @@ def torch_nn_func_batchnorm(input, training=False, momentum=0.1, eps=1e-05): - return torch.empty(input.shape, device='meta') \ No newline at end of file + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py index 72cd43674..4c171cb10 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py @@ -1,8 +1,11 @@ import operator + import torch -from ..registry import meta_patched_function + from colossalai.fx.proxy import ColoProxy +from ...registry import meta_patched_function + @meta_patched_function.register(operator.getitem) def operator_getitem(a, b): diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py index 229443ed9..b14ff10ce 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -1,5 +1,6 @@ import torch -from ..registry import meta_patched_function + +from ...registry import meta_patched_function @meta_patched_function.register(torch.arange) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py index ed572e3b7..d03da6588 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py @@ -1,5 +1,6 @@ import torch -from ..registry import meta_patched_module + +from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.ReLU) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py index 32bf1b8da..cf9f3487a 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py @@ -1,6 +1,8 @@ import math + import torch -from ..registry import meta_patched_module + +from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.Conv1d) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py index 705d37735..999e33b17 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py @@ -1,8 +1,9 @@ import torch -from ..registry import meta_patched_module + +from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.Embedding) def torch_nn_embedding(self, input): result_shape = input.shape + (self.embedding_dim,) - return torch.empty(result_shape, device='meta') \ No newline at end of file + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py index 0275f134d..56f13bf97 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/linear.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py @@ -1,5 +1,6 @@ import torch -from ..registry import meta_patched_module + +from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.Linear) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py index e83b31b67..c21ff64cf 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -1,5 +1,6 @@ import torch -from ..registry import meta_patched_module + +from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.LayerNorm) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py index f740f8511..7ce23fbf7 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py @@ -1,6 +1,8 @@ import math + import torch -from ..registry import meta_patched_module + +from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.AvgPool1d) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py index 15a0be417..ee15ca341 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py @@ -1,7 +1,9 @@ -import torch -from ..registry import meta_patched_module from typing import Optional +import torch + +from ...registry import meta_patched_module + @meta_patched_module.register(torch.nn.GRU) @meta_patched_module.register(torch.nn.RNN) diff --git a/colossalai/fx/tracer/meta_patch/registry.py b/colossalai/fx/tracer/registry.py similarity index 71% rename from colossalai/fx/tracer/meta_patch/registry.py rename to colossalai/fx/tracer/registry.py index 3eeafe448..12fc6de73 100644 --- a/colossalai/fx/tracer/meta_patch/registry.py +++ b/colossalai/fx/tracer/registry.py @@ -23,3 +23,6 @@ class PatchRegistry: meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution') meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution') +bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition') +bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition') +bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition') diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py index bccdbf2ce..1ae31f958 100644 --- a/colossalai/fx/tracer/tracer.py +++ b/colossalai/fx/tracer/tracer.py @@ -1,26 +1,33 @@ #!/usr/bin/env python """ -tracer.py: +tracer.py: Implemented a tracer which supports control flow and user-defined meta arguments. The implementation is partly inspired HuggingFace's fx tracer """ import enum -import inspect import functools +import inspect import operator from contextlib import contextmanager -from colossalai.fx.tracer.meta_patch import meta_patched_module +from typing import Any, Dict, Optional + import torch import torch.nn as nn from torch import Tensor -from torch.fx import Tracer, Node -from torch.fx.graph import Graph -from torch.fx.proxy import Proxy, ParameterProxy +from torch.fx import Node, Tracer +from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods +from torch.fx.proxy import ParameterProxy, Proxy + from ..proxy import ColoProxy -from typing import Optional, Dict, Any -from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy -from .meta_patch import meta_patched_function, meta_patched_module -from torch.fx.graph import magic_methods, reflectable_magic_methods +from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list +from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict +from .registry import ( + bias_addition_function, + bias_addition_method, + bias_addition_module, + meta_patched_function, + meta_patched_module, +) __all__ = ['ColoTracer'] @@ -77,88 +84,69 @@ class ColoTracer(Tracer): """ Create a proxy for different kinds of operations. """ - proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) if self.tracer_type == TracerType.DEFAULT: # since meta_args is not given # we just fall back to the original torch.fx.Tracer + proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) return proxy + # if graph is traced for auto parallelism module, some extra node will be added during + # graph construction to deal with the compatability between bias addition and all reduce. + + # if no extra manipulation is applied, we just pass the origin arguments to create_proxy function + # to create node on computation graph + origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + # dispatch the arguments generator depending on the kind and target in origin arguments. + args_metas, _ = extract_meta(*args, **kwargs) + handle = None + if kind == "call_function": + if bias_addition_function.has(target): + if target == torch.nn.functional.linear: + if 'bias' in kwargs and kwargs['bias'] is not None: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) + else: + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) + elif bias_addition_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + function_to_substitute = func_to_func_dict[target] + handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute) + + elif kind == "call_method": + method = getattr(args_metas[0].__class__, target) + if bias_addition_method.has(method): + function_to_substitute = method_to_func_dict[method] + handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute) + + elif kind == "call_module": + if not hasattr(self, "orig_forward"): + raise AttributeError(f"{self} does not have an attribute called orig_forward") + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if bias_addition_module.has(mod_type) and mod.bias is not None: + function_to_substitute = module_to_func_dict[mod_type] + handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute) + finally: + self._disable_module_getattr = False + + if handle is not None: + return handle.generate() + + # create nodes using patched arguments + proxy = super().create_proxy(*origin_arguments) proxy: ColoProxy + meta_out = self._meta_data_computing( + kind, + target, + args, + kwargs, + ) + proxy.meta_data = meta_out - if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta: - proxy.meta_data = self.meta_args[target] - return proxy - - if target in self.orig_torch_tensor_methods: - # NOTE: tensor constructors in PyTorch define the `device` argument as - # *kwargs-only*. That is why this works. If you add methods to - # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, - # this will break and you will likely see issues where we cannot infer - # the size of the output. - if "device" in kwargs: - kwargs["device"] = "meta" - - try: - args_metas, kwargs_metas = extract_meta(*args, **kwargs) - - if kind == "call_function": - # fetch patched function - if meta_patched_function.has(target): - meta_target = meta_patched_function.get(target) - elif meta_patched_function.has(target.__name__): - # use name for some builtin op like @ (matmul) - meta_target = meta_patched_function.get(target.__name__) - else: - meta_target = target - - meta_out = meta_target(*args_metas, **kwargs_metas) - if isinstance(meta_out, torch.Tensor): - meta_out = meta_out.to(device="meta") - elif kind == "call_method": - method = getattr(args_metas[0].__class__, target) - - # fetch patched method - if meta_patched_function.has(method): - meta_target = meta_patched_function.get(method) - else: - meta_target = method - - meta_out = meta_target(*args_metas, **kwargs_metas) - elif kind == "call_module": - if not hasattr(self, "orig_forward"): - raise AttributeError(f"{self} does not have an attribute called orig_forward") - self._disable_module_getattr = True - try: - mod = self.root.get_submodule(target) - mod_type = type(mod) - if meta_patched_module.has(mod_type): - meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas) - else: - meta_out = self.orig_forward(*args_metas, **kwargs_metas) - finally: - self._disable_module_getattr = False - elif kind == "get_attr": - self._disable_module_getattr = True - try: - attr_itr = self.root - atoms = target.split(".") - for atom in atoms: - attr_itr = getattr(attr_itr, atom) - if isinstance(attr_itr, torch.Tensor): - meta_out = attr_itr.to(device="meta") - else: - meta_out = attr_itr - finally: - self._disable_module_getattr = False - else: - return proxy - - if not isinstance(proxy, Proxy): - raise ValueError("Don't support composite output yet") - proxy.meta_data = meta_out - except Exception as e: - raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}") return proxy def _module_getattr(self, attr, attr_val, parameter_proxy_cache): @@ -222,6 +210,105 @@ class ColoTracer(Tracer): else: raise ValueError(f"Unrecognised tracer type {tracer_type}") + def _meta_data_computing(self, kind, target, args, kwargs): + + if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta: + meta_out = self.meta_args[target] + return meta_out + + if target in self.orig_torch_tensor_methods: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if "device" in kwargs: + kwargs["device"] = "meta" + + try: + args_metas, kwargs_metas = extract_meta(*args, **kwargs) + + if kind == "call_function": + # Our meta data will not record the nn.parameter.Parameter attribute。 + # It works fine in most of the case, but it may cause some problems after + # the bias addition manipulation. + # Therefore, I need to record the nn.parameter.Parameter attribute for the operation + # added by the bias addition manipulation following the get_attr node. + convert_to_parameter = False + if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0], + torch.nn.parameter.Parameter): + convert_to_parameter = True + # fetch patched function + if meta_patched_function.has(target): + meta_target = meta_patched_function.get(target) + elif meta_patched_function.has(target.__name__): + # use name for some builtin op like @ (matmul) + meta_target = meta_patched_function.get(target.__name__) + else: + meta_target = target + + meta_out = meta_target(*args_metas, **kwargs_metas) + if isinstance(meta_out, torch.Tensor): + meta_out = meta_out.to(device="meta") + if convert_to_parameter: + meta_out = torch.nn.Parameter(meta_out) + + elif kind == "call_method": + # Our meta data will not record the nn.parameter.Parameter attribute。 + # It works fine in most of the case, but it may cause some problems after + # the bias addition manipulation. + # Therefore, I need to record the nn.parameter.Parameter attribute for the operation + # added by the bias addition manipulation following the get_attr node. + convert_to_parameter = False + if target in (torch.Tensor.view,) and isinstance(args_metas[0], torch.nn.parameter.Parameter): + convert_to_parameter = True + method = getattr(args_metas[0].__class__, target) + + # fetch patched method + if meta_patched_function.has(method): + meta_target = meta_patched_function.get(method) + else: + meta_target = method + + meta_out = meta_target(*args_metas, **kwargs_metas) + if convert_to_parameter: + meta_out = torch.nn.Parameter(meta_out) + elif kind == "call_module": + if not hasattr(self, "orig_forward"): + raise AttributeError(f"{self} does not have an attribute called orig_forward") + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if meta_patched_module.has(mod_type): + meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas) + else: + meta_out = self.orig_forward(*args_metas, **kwargs_metas) + finally: + self._disable_module_getattr = False + elif kind == "get_attr": + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.nn.parameter.Parameter): + meta_out = torch.nn.Parameter(attr_itr.to(device="meta")) + elif isinstance(attr_itr, torch.Tensor): + meta_out = attr_itr.to(device="meta") + else: + meta_out = attr_itr + finally: + self._disable_module_getattr = False + else: + return None + + except Exception as e: + raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}") + + return meta_out + def trace(self, root: nn.Module, concrete_args: Optional[Dict[str, Tensor]] = None, @@ -231,7 +318,7 @@ class ColoTracer(Tracer): Args: root (nn.Module): a `nn.Module` object to trace the computation graph - meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph. + meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph. These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors. concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies. """ @@ -383,7 +470,7 @@ class ColoTracer(Tracer): if self.inside_torch_checkpoint_func: # annotate the activation checkpoint module - setattr(node, 'activation_checkpoint', self.act_ckpt_region_count) + node.meta['activation_checkpoint'] = self.act_ckpt_region_count return node diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py index a82640d67..7a5a44ebb 100644 --- a/colossalai/gemini/__init__.py +++ b/colossalai/gemini/__init__.py @@ -1,6 +1,9 @@ -from .chunk import TensorInfo, TensorState +from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration +from .gemini_mgr import GeminiManager from .stateful_tensor_mgr import StatefulTensorMgr from .tensor_placement_policy import TensorPlacementPolicyFactory -from .gemini_mgr import GeminiManager -__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState'] +__all__ = [ + 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', + 'search_chunk_configuration' +] diff --git a/colossalai/gemini/chunk/__init__.py b/colossalai/gemini/chunk/__init__.py index 86ff785f7..6914d2dbe 100644 --- a/colossalai/gemini/chunk/__init__.py +++ b/colossalai/gemini/chunk/__init__.py @@ -1,4 +1,6 @@ from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState from .manager import ChunkManager -from .search_utils import clasify_params, search_chunk_configuration +from .search_utils import classify_params_by_dp_degree, search_chunk_configuration from .utils import init_chunk_manager + +__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager'] diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/gemini/chunk/chunk.py index 648d48ec5..a7682eaf6 100644 --- a/colossalai/gemini/chunk/chunk.py +++ b/colossalai/gemini/chunk/chunk.py @@ -1,552 +1,576 @@ -import torch -import torch.distributed as dist -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Dict, List - -from colossalai.utils import get_current_device -from colossalai.tensor import ProcessGroup as ColoProcessGroup - - -class TensorState(Enum): - FREE = 0 - COMPUTE = 1 - HOLD = 2 - HOLD_AFTER_BWD = 3 - READY_FOR_REDUCE = 4 - - -STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), - (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), - (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), - (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), - (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, - TensorState.HOLD)) - - -@dataclass -class TensorInfo: - state: TensorState - offset: int - end: int - - -class ChunkFullError(Exception): - pass - - -def is_storage_empty(tensor: torch.Tensor) -> bool: - return tensor.storage().size() == 0 - - -def free_storage(tensor: torch.Tensor) -> None: - if not is_storage_empty(tensor): - tensor.storage().resize_(0) - - -def alloc_storage(tensor: torch.Tensor) -> None: - if is_storage_empty(tensor): - tensor.storage().resize_(tensor.numel()) - - -class Chunk: - - _total_number = 0 - - def __init__(self, - chunk_size: int, - process_group: ColoProcessGroup, - dtype: torch.dtype, - init_device: Optional[torch.device] = None, - keep_gathered: bool = False, - pin_memory: bool = False) -> None: - """ - Chunk: A container owning a piece of contiguous memory space for tensors - Here we use all-gather operation to gather the whole chunk. - Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters. - It is designed to make the full use of communication and PCIE bandwidth. - - Args: - chunk_size (int): the number of elements in the chunk - process_group (ColoProcessGroup): the process group of this chunk - dtype (torch.dtype): the data type of the chunk - init_device (torch.device): optional, the device where the tensor is initialized - The default value is None, which is the current GPU - keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory - pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory - """ - self.count_id = Chunk._total_number - Chunk._total_number += 1 - - self.chunk_size = chunk_size - self.utilized_size = 0 - # Here, we use torch process group, - # since ColoProcessGroup might get deprecated soon - self.torch_pg = process_group.dp_process_group() - self.pg_size = dist.get_world_size(self.torch_pg) - self.pg_rank = dist.get_rank(self.torch_pg) - - # the chunk size should be able to be divied by the size of GPU - if not keep_gathered: - assert chunk_size % self.pg_size == 0 - self.shard_size = chunk_size // self.pg_size - self.shard_begin = self.shard_size * self.pg_rank - self.shard_end = self.shard_begin + self.shard_size - self.valid_end = self.shard_size - - self.dtype = dtype - device = init_device or get_current_device() - self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero - self.chunk_total = None # we force chunk_total located in CUDA - self.cuda_shard = None # using two attributes for the better interpretation - self.cpu_shard = None - self.is_gathered = True - - self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() - self.shard_mem = self.chunk_mem // self.pg_size - - # each tensor is associated with a TensorInfo to track meta info - self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} - # the total number of all tensors - self.num_tensors = 0 - # monitor the states of all tensors - self.tensors_state_monitor: Dict[TensorState, int] = dict() - for state in TensorState: - self.tensors_state_monitor[state] = 0 - - # some chunks can keep gathered all the time - # so their computation patterns are the same as that of the parameters in DDP - self.keep_gathered = keep_gathered - if self.keep_gathered: - pin_memory = False # since this chunk is gathered, it doesn't need to pin - - # if pin_memory is True, we allocate a piece of CPU pin-memory - # for it all the time - self.pin_memory = pin_memory - - # we introduce the paired chunk here - # it refers to another chunk having the same parameters - # but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk - self.paired_chunk = None - # if this chunk is synchronized with the optimizer, the flag is True - self.optim_sync_flag = True - # if the cpu_shard has been visited during the training step, the flag is True - self.cpu_vis_flag = False - - @property - def memory_usage(self) -> Dict[str, int]: - cuda_memory = 0 - cpu_memory = 0 - - if self.chunk_temp is not None: - # this chunk is not closed - if self.chunk_temp.device.type == 'cuda': - cuda_memory += self.chunk_mem - else: - cpu_memory += self.chunk_mem - else: - if self.is_gathered: - cuda_memory += self.chunk_mem - if self.cuda_shard is not None: - cuda_memory += self.shard_mem - if self.cpu_shard is not None: - cpu_memory += self.shard_mem - - return dict(cuda=cuda_memory, cpu=cpu_memory) - - @property - def device_type(self) -> str: - if self.chunk_temp is not None: - return self.chunk_temp.device.type - else: - if self.is_gathered: - return 'cuda' - elif self.cuda_shard is not None: - return 'cuda' - else: - return 'cpu' - - @property - def payload(self) -> torch.Tensor: - # sanity check - assert self.chunk_temp is None - - if self.is_gathered: - return self.chunk_total - elif self.cuda_shard is not None: - return self.cuda_shard - else: - return self.cpu_shard - - @property - def payload_mem(self) -> int: - # sanity check - assert self.chunk_temp is None - - if self.is_gathered: - return self.chunk_mem - else: - return self.shard_mem - - @property - def can_move(self) -> bool: - return not self.is_gathered - - @property - def can_release(self) -> bool: - if self.keep_gathered: - return False - else: - return self.tensors_state_monitor[TensorState.HOLD] + \ - self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors - - @property - def can_reduce(self): - return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors - - @property - def has_inf_or_nan(self) -> bool: - """Check if the chunk has inf or nan values in CUDA. - """ - if self.is_gathered: - valid_tensor = self.chunk_total[:self.utilized_size] - else: - assert self.cuda_shard is not None # only check in CUDA - valid_tensor = self.cuda_shard[:self.valid_end] - - return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() - - def append_tensor(self, tensor: torch.Tensor): - """Add a tensor to the chunk. - - Args: - tensor (torch.Tensor): a tensor to be added to the chunk - """ - # sanity check - assert self.chunk_temp is not None - assert tensor.dtype == self.dtype - - new_utilized_size = self.utilized_size + tensor.numel() - # raise exception when the chunk size is exceeded - if new_utilized_size > self.chunk_size: - raise ChunkFullError - - self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) - assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" - tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) - - # record all the information about the tensor - self.num_tensors += 1 - tensor_state = TensorState.HOLD - self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) - self.tensors_state_monitor[tensor_state] += 1 - self.utilized_size = new_utilized_size - - def close_chunk(self, shard_dev: Optional[torch.device] = None): - """Close the chunk. Any tensor can't be appended to a closed chunk later. - - Args: - shard_dev: the device where the shard locates - """ - # sanity check - assert self.chunk_temp is not None - - # calculate the valid end for each shard - if self.utilized_size <= self.shard_begin: - self.valid_end = 0 - elif self.utilized_size < self.shard_end: - self.valid_end = self.utilized_size - self.shard_begin - - if self.chunk_temp.device.type == 'cpu': - self.chunk_total = self.chunk_temp.to(get_current_device()) - self.__update_tensors_ptr() - else: - self.chunk_total = self.chunk_temp - self.chunk_temp = None - - self.__scatter() - - if self.keep_gathered: - if shard_dev is None: - shard_dev = get_current_device() - else: - assert shard_dev.type == 'cuda' - elif shard_dev is None: - shard_dev = torch.device('cpu') - - if self.pin_memory or shard_dev.type == 'cpu': - self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) - self.cpu_shard.copy_(self.cuda_shard) - self.cpu_vis_flag = True # cpu_shard has been visited - - if shard_dev.type == 'cpu': - self.cuda_shard = None - - def shard_move(self, device: torch.device, force_copy: bool = False): - """Move the shard tensor in the chunk. - - Args: - device: the device to which the shard will move - force_copy: if True, copy function is called mandatorily - """ - # sanity check - assert not self.is_gathered - # when the current chunk is not synchronized with the optimizer - # just use another way for the movement - if not self.optim_sync_flag: - assert device.type == 'cuda', "each chunk should first be moved to CUDA" - self.__paired_shard_move() - self.optim_sync_flag = True - return - - if device.type == 'cuda': - assert device == get_current_device(), "can't move chunk to another device" - - if self.cuda_shard: - return - - self.cuda_shard = self.cpu_shard.to(get_current_device()) - - if not self.pin_memory: - self.cpu_shard = None - elif device.type == 'cpu': - if self.cuda_shard is None: - return - - if self.pin_memory: - if force_copy or not self.cpu_vis_flag: - self.cpu_shard.copy_(self.cuda_shard) - # if cpu_shard has been visited - # copy operation is not need - else: - self.cpu_shard = self.cuda_shard.cpu() - self.cpu_vis_flag = True - self.cuda_shard = None - else: - raise NotImplementedError - - def access_chunk(self): - """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. - """ - # sanity check - assert self.chunk_temp is None - - if not self.is_gathered: - self.__gather() - self.__update_tensors_ptr() - - def release_chunk(self): - """Release the usable chunk. It's an operation done in CUDA. - """ - # sanity check - assert self.chunk_temp is None - - if self.is_gathered: - self.__scatter() - - def reduce(self): - """Reduce scatter all the gradients. It's an operation done in CUDA. - """ - # sanity check - assert self.is_gathered - - if self.pg_size == 1: - # tricky code here - # just move chunk_total to cuda_shard - # the communication is not necessary - self.__scatter() - elif self.keep_gathered: - # we use all-reduce here - dist.all_reduce(self.chunk_total, group=self.torch_pg) - else: - self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) - - input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0)) - dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) - - free_storage(self.chunk_total) - self.is_gathered = False - self.__update_tensors_state(TensorState.HOLD) - - def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: - """ - Make a transition of the tensor into the next state. - - Args: - tensor (torch.Tensor): a torch Tensor object. - tensor_state (TensorState): the target state for transition. - """ - - # As the gradient hook can be triggered either before or after post-backward - # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce - # or compute -> ready_for_reduce -> hold_after_bwd - # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd - # this function only apply valid state transformation - # invalid calls will be ignored and nothing changes - if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: - return - self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) - - def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: - """ - Copy data slice to the memory space indexed by the input tensor in the chunk. - - Args: - tensor (torch.Tensor): the tensor used to retrive meta information - data_slice (torch.Tensor): the tensor to be copied to the chunk - """ - # sanity check - assert self.is_gathered - - tensor_info = self.tensors_info[tensor] - self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) - tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape) - - def get_valid_length(self) -> int: - """Get the valid length of the chunk's payload. - """ - if self.keep_gathered: - return self.utilized_size - else: - return self.valid_end - - def init_pair(self, friend_chunk: 'Chunk') -> None: - """Initialize the paired chunk. - """ - if self.paired_chunk is None and friend_chunk.paired_chunk is None: - self.paired_chunk = friend_chunk - friend_chunk.paired_chunk = self - else: - assert self.paired_chunk is friend_chunk - assert friend_chunk.paired_chunk is self - - def optim_update(self) -> None: - """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. - """ - # sanity check - assert self.paired_chunk is not None - - friend_chunk = self.paired_chunk - if self.is_gathered is True: - assert friend_chunk.is_gathered is True - self.chunk_total.copy_(friend_chunk.chunk_total) - self.optim_sync_flag = True - elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': - self.cuda_shard.copy_(friend_chunk.cuda_shard) - self.optim_sync_flag = True - self.cpu_vis_flag = False - else: - # optim_sync_flag is set to False - # see shard_move function for more details - assert friend_chunk.device_type == 'cpu' - assert self.device_type == 'cpu' - self.optim_sync_flag = False - self.cpu_vis_flag = False - - def get_tensors(self) -> List[torch.Tensor]: - return list(self.tensors_info.keys()) - - def __gather(self): - if not self.is_gathered: - # sanity check - assert self.cuda_shard is not None - - alloc_storage(self.chunk_total) - gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0)) - dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) - - self.cuda_shard = None - self.is_gathered = True - - def __scatter(self): - if self.keep_gathered: - return - - if self.is_gathered: - # sanity check - assert self.cuda_shard is None - - self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device) - - self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end]) - - free_storage(self.chunk_total) - self.is_gathered = False - - def __paired_shard_move(self): - assert self.paired_chunk is not None, "chunks should be paired before training" - optim_chunk = self.paired_chunk - assert self.chunk_size == optim_chunk.chunk_size - - # only be called when optimizer state is in CPU memory - # the grad and param should be in the same device - assert self.cuda_shard is None - temp = optim_chunk.cpu_shard.to(get_current_device()) - # avoid to transform FP32 in CPU - self.cuda_shard = temp.to(self.dtype) - - if not self.pin_memory: - self.cpu_shard = None - - def __update_tensors_ptr(self) -> None: - # sanity check - assert self.is_gathered - assert type(self.chunk_total) == torch.Tensor - - for tensor, tensor_info in self.tensors_info.items(): - tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape) - - def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): - self.tensors_state_monitor[tensor_info.state] -= 1 - tensor_info.state = next_state - self.tensors_state_monitor[tensor_info.state] += 1 - - def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): - for tensor_info in self.tensors_info.values(): - if prev_state is None or tensor_info.state == prev_state: - self.__update_one_tensor_info(tensor_info, next_state) - - def __hash__(self) -> int: - return hash(id(self)) - - def __eq__(self, __o: object) -> bool: - return self is __o - - def __repr__(self, detailed: bool = True): - output = [ - "Chunk Information:\n", - "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, - self.pg_size), - "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( - self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size) - ] - - def print_tensor(tensor, prefix=''): - output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, - tensor.device)) - - if self.chunk_temp is not None: - output.append("\tchunk temp:\n") - print_tensor(tensor=self.chunk_temp, prefix='\t\t') - - if self.chunk_total is not None and self.chunk_total.storage().size() > 0: - output.append("\tchunk total:\n") - print_tensor(tensor=self.chunk_total, prefix='\t\t') - - if self.cuda_shard is not None: - output.append("\tcuda shard:\n") - print_tensor(tensor=self.cuda_shard, prefix='\t\t') - - if self.cpu_shard is not None: - output.append("\tcpu shard:\n") - print_tensor(tensor=self.cpu_shard, prefix='\t\t') - - memory_info = self.memory_usage - output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu'])) - - if detailed: - output.append("\ttensor state monitor:\n") - for st in TensorState: - output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st])) - - return ''.join(output) +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist + +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.utils import get_current_device + + +class TensorState(Enum): + FREE = 0 + COMPUTE = 1 + HOLD = 2 + HOLD_AFTER_BWD = 3 + READY_FOR_REDUCE = 4 + + +STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE, + TensorState.HOLD), + (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, + TensorState.HOLD)) + + +@dataclass +class TensorInfo: + state: TensorState + offset: int + end: int + + +class ChunkFullError(Exception): + pass + + +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + +class Chunk: + _total_number = 0 + + def __init__(self, + chunk_size: int, + process_group: ColoProcessGroup, + dtype: torch.dtype, + init_device: Optional[torch.device] = None, + cpu_shard_init: bool = False, + keep_gathered: bool = False, + pin_memory: bool = False) -> None: + """ + Chunk: A container owning a piece of contiguous memory space for tensors + Here we use all-gather operation to gather the whole chunk. + Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters. + It is designed to make the full use of communication and PCIE bandwidth. + + Args: + chunk_size (int): the number of elements in the chunk + process_group (ColoProcessGroup): the process group of this chunk + dtype (torch.dtype): the data type of the chunk + init_device (torch.device): optional, During the chunk construction process, where the tensor is stored. + The default value is None, which is the current GPU + cpu_shard_init (bool): a flag indicates the local chunk shard is resident on CPU. + keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory + pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory + """ + self.count_id = Chunk._total_number + Chunk._total_number += 1 + + self.chunk_size = chunk_size + self.utilized_size = 0 + + self.torch_pg = process_group.dp_process_group() + self.pg_size = dist.get_world_size(self.torch_pg) + self.pg_rank = dist.get_rank(self.torch_pg) + + # the chunk size should be divisible by the dp degree + if not keep_gathered: + assert chunk_size % self.pg_size == 0 + self.shard_size = chunk_size // self.pg_size + self.shard_begin = self.shard_size * self.pg_rank + self.shard_end = self.shard_begin + self.shard_size + self.valid_end = self.shard_size + + self.dtype = dtype + device = init_device or get_current_device() + + # chunk_temp is a global chunk, which only exists during building the chunks. + self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero + + self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA + + # cuda local chunk, which is sharded on GPUs + self.cuda_shard = None + # cpu local chunk, which is sharded on CPUs + self.cpu_shard = None + # is the chunks gathers, which means chunks are duplicated on each process, + # and we should use the cuda_global_chunk. + self.is_gathered = True + + # configure the init device of the shard + # no-offload default: fp16, fp32 -> CUDA + # offload default: fp16, fp32 -> CPU + self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device() + + self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() + self.shard_mem = self.chunk_mem // self.pg_size + + # each tensor is associated with a TensorInfo to track its meta info + # (state, offset, end) + self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} + # the total number of tensors in the chunk + self.num_tensors = 0 + + # Record the number of tensors in different states + self.tensor_state_cnter: Dict[TensorState, int] = dict() + for state in TensorState: + self.tensor_state_cnter[state] = 0 + + # If a chunk is kept gathered, + # they are treated the same as that of the parameters in DDP during training. + self.keep_gathered = keep_gathered + if self.keep_gathered: + pin_memory = False # since this chunk is gathered, it doesn't need to pin + + # if pin_memory is True, we allocate a piece of CPU pin-memory + # for it all the time + self.pin_memory = pin_memory + + # we introduce the paired chunk here + # it refers to another chunk having the same parameters + # but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk + self.paired_chunk = None + # if this chunk is synchronized with the optimizer, the flag is True + self.optim_sync_flag = True + # if the cpu_shard has been visited during the training step, the flag is True + self.cpu_vis_flag = False + + # whether to record l2 norm for the gradient clipping calculation + self.l2_norm_flag = False + self.l2_norm = None + + @property + def memory_usage(self) -> Dict[str, int]: + cuda_memory = 0 + cpu_memory = 0 + + if self.chunk_temp is not None: + # this chunk is not closed + if self.chunk_temp.device.type == 'cuda': + cuda_memory += self.chunk_mem + else: + cpu_memory += self.chunk_mem + else: + if self.is_gathered: + cuda_memory += self.chunk_mem + if self.cuda_shard is not None: + cuda_memory += self.shard_mem + if self.cpu_shard is not None: + cpu_memory += self.shard_mem + + return dict(cuda=cuda_memory, cpu=cpu_memory) + + @property + def device_type(self) -> str: + if self.chunk_temp is not None: + return self.chunk_temp.device.type + else: + if self.is_gathered: + return 'cuda' + elif self.cuda_shard is not None: + return 'cuda' + else: + return 'cpu' + + @property + def payload(self) -> torch.Tensor: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.cuda_global_chunk + elif self.cuda_shard is not None: + return self.cuda_shard + else: + return self.cpu_shard + + @property + def payload_mem(self) -> int: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.chunk_mem + else: + return self.shard_mem + + @property + def can_move(self) -> bool: + return not self.is_gathered + + @property + def can_release(self) -> bool: + if self.keep_gathered: + return False + else: + return self.tensor_state_cnter[TensorState.HOLD] + \ + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors + + @property + def can_reduce(self): + return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors + + @property + def has_inf_or_nan(self) -> bool: + """Check if the chunk has inf or nan values on CUDA. + """ + if self.is_gathered: + valid_tensor = self.cuda_global_chunk[:self.utilized_size] + else: + assert self.cuda_shard is not None # only check on CUDA + valid_tensor = self.cuda_shard[:self.valid_end] + + return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() + + def set_l2_norm(self) -> None: + """Record l2 norm of this chunks on CUDA. + """ + assert self.l2_norm is None, "you are calculating the l2 norm twice" + if self.is_gathered: + valid_tensor = self.cuda_global_chunk[:self.utilized_size] + else: + assert self.cuda_shard is not None # calculate on CUDA + valid_tensor = self.cuda_shard[:self.valid_end] + chunk_l2_norm = valid_tensor.data.float().norm(2) + self.l2_norm = chunk_l2_norm.item()**2 + + def append_tensor(self, tensor: torch.Tensor): + """Add a tensor to the chunk. + + Args: + tensor (torch.Tensor): a tensor to be added to the chunk + """ + # sanity check + assert self.chunk_temp is not None + assert tensor.dtype == self.dtype + + new_utilized_size = self.utilized_size + tensor.numel() + # raise exception when the chunk size is exceeded + if new_utilized_size > self.chunk_size: + raise ChunkFullError + + self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) + assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" + tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) + + # record all the information about the tensor + self.num_tensors += 1 + tensor_state = TensorState.HOLD + self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) + self.tensor_state_cnter[tensor_state] += 1 + self.utilized_size = new_utilized_size + + def close_chunk(self): + """Close the chunk. Any tensor can't be appended to a closed chunk later. + """ + # sanity check + assert self.chunk_temp is not None + + # calculate the valid end for each shard + if self.utilized_size <= self.shard_begin: + self.valid_end = 0 + elif self.utilized_size < self.shard_end: + self.valid_end = self.utilized_size - self.shard_begin + + if self.chunk_temp.device.type == 'cpu': + self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) + self.__update_tensors_ptr() + else: + self.cuda_global_chunk = self.chunk_temp + self.chunk_temp = None + + self.__scatter() + # gathered chunk never have shard attribute + if self.keep_gathered: + return + + if self.pin_memory or self.shard_device.type == 'cpu': + self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) + self.cpu_shard.copy_(self.cuda_shard) + self.cpu_vis_flag = True # cpu_shard has been visited + + if self.shard_device.type == 'cpu': + self.cuda_shard = None + + def shard_move(self, device: torch.device, force_copy: bool = False): + """Move the shard tensor in the chunk. + + Args: + device: the device to which the shard will move + force_copy: if True, copy function is called mandatorily + """ + # sanity check + assert not self.is_gathered + # when the current chunk is not synchronized with the optimizer + # just use another way for the movement + if not self.optim_sync_flag: + assert device.type == 'cuda', "each chunk should first be moved to CUDA" + self.__paired_shard_move() + self.optim_sync_flag = True + return + + if device.type == 'cuda': + assert device == get_current_device(), "can't move chunk to another device" + + if self.cuda_shard: + return + + self.cuda_shard = self.cpu_shard.to(get_current_device()) + + if not self.pin_memory: + self.cpu_shard = None + elif device.type == 'cpu': + if self.cuda_shard is None: + return + + if self.pin_memory: + if force_copy or not self.cpu_vis_flag: + self.cpu_shard.copy_(self.cuda_shard) + # if cpu_shard has been visited + # copy operation is not need + else: + self.cpu_shard = self.cuda_shard.cpu() + self.cpu_vis_flag = True + self.cuda_shard = None + else: + raise NotImplementedError + + def access_chunk(self): + """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. + """ + # sanity check + assert self.chunk_temp is None + + if not self.is_gathered: + self.__gather() + self.__update_tensors_ptr() + + def release_chunk(self): + """Release the usable chunk. It's an operation done in CUDA. + """ + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + self.__scatter() + + def reduce(self): + """Reduce scatter all the gradients. It's an operation done in CUDA. + """ + # sanity check + assert self.is_gathered + + if self.pg_size == 1: + # tricky code here + # just move cuda_global_chunk to cuda_shard + # the communication is not necessary + self.__scatter() + elif self.keep_gathered: + # we use all-reduce here + dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg) + else: + self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) + + input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) + dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) + + free_storage(self.cuda_global_chunk) + self.is_gathered = False + self.__update_tensors_state(TensorState.HOLD) + + def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: + """ + Make a transition of the tensor into the next state. + + Args: + tensor (torch.Tensor): a torch Tensor object. + tensor_state (TensorState): the target state for transition. + """ + + # As the gradient hook can be triggered either before or after post-backward + # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce + # or compute -> ready_for_reduce -> hold_after_bwd + # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd + # this function only apply valid state transformation + # invalid calls will be ignored and nothing changes + if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: + return + self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Copy data slice to the memory space indexed by the input tensor in the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data_slice (torch.Tensor): the tensor to be copied to the chunk + """ + # sanity check + assert self.is_gathered + + tensor_info = self.tensors_info[tensor] + self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) + tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def get_valid_length(self) -> int: + """Get the valid length of the chunk's payload. + """ + if self.keep_gathered: + return self.utilized_size + else: + return self.valid_end + + def init_pair(self, friend_chunk: 'Chunk') -> None: + """Initialize the paired chunk. + """ + if self.paired_chunk is None and friend_chunk.paired_chunk is None: + self.paired_chunk = friend_chunk + friend_chunk.paired_chunk = self + else: + assert self.paired_chunk is friend_chunk + assert friend_chunk.paired_chunk is self + + def optim_update(self) -> None: + """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. + """ + # sanity check + assert self.paired_chunk is not None + + friend_chunk = self.paired_chunk + if self.is_gathered is True: + assert friend_chunk.is_gathered is True + self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk) + self.optim_sync_flag = True + elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': + self.cuda_shard.copy_(friend_chunk.cuda_shard) + self.optim_sync_flag = True + self.cpu_vis_flag = False + else: + # optim_sync_flag is set to False + # see shard_move function for more details + assert friend_chunk.device_type == 'cpu' + assert self.device_type == 'cpu' + self.optim_sync_flag = False + self.cpu_vis_flag = False + + def get_tensors(self) -> List[torch.Tensor]: + return list(self.tensors_info.keys()) + + def __gather(self): + if not self.is_gathered: + # sanity check + assert self.cuda_shard is not None + + alloc_storage(self.cuda_global_chunk) + gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0)) + dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) + + self.cuda_shard = None + self.is_gathered = True + + def __scatter(self): + if self.keep_gathered: + return + + if self.is_gathered: + # sanity check + assert self.cuda_shard is None + + self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device) + + self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end]) + + free_storage(self.cuda_global_chunk) + self.is_gathered = False + + def __paired_shard_move(self): + assert self.paired_chunk is not None, "chunks should be paired before training" + optim_chunk = self.paired_chunk + assert self.chunk_size == optim_chunk.chunk_size + + # only be called when optimizer state is in CPU memory + # the grad and param should be in the same device + assert self.cuda_shard is None + temp = optim_chunk.cpu_shard.to(get_current_device()) + # avoid to transform FP32 in CPU + self.cuda_shard = temp.to(self.dtype) + + if not self.pin_memory: + self.cpu_shard = None + + def __update_tensors_ptr(self) -> None: + # sanity check + assert self.is_gathered + assert type(self.cuda_global_chunk) == torch.Tensor + + for tensor, tensor_info in self.tensors_info.items(): + tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): + self.tensor_state_cnter[tensor_info.state] -= 1 + tensor_info.state = next_state + self.tensor_state_cnter[tensor_info.state] += 1 + + def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): + for tensor_info in self.tensors_info.values(): + if prev_state is None or tensor_info.state == prev_state: + self.__update_one_tensor_info(tensor_info, next_state) + + def __hash__(self) -> int: + return hash(id(self)) + + def __eq__(self, __o: object) -> bool: + return self is __o + + def __repr__(self, detailed: bool = True): + output = [ + "Chunk Information:\n", + "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, + self.pg_size), + "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( + self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size) + ] + + def print_tensor(tensor, prefix=''): + output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, + tensor.device)) + + if self.chunk_temp is not None: + output.append("\tchunk temp:\n") + print_tensor(tensor=self.chunk_temp, prefix='\t\t') + + if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0: + output.append("\tchunk total:\n") + print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t') + + if self.cuda_shard is not None: + output.append("\tcuda shard:\n") + print_tensor(tensor=self.cuda_shard, prefix='\t\t') + + if self.cpu_shard is not None: + output.append("\tcpu shard:\n") + print_tensor(tensor=self.cpu_shard, prefix='\t\t') + + memory_info = self.memory_usage + output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu'])) + + if detailed: + output.append("\ttensor state monitor:\n") + for st in TensorState: + output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st])) + + return ''.join(output) diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py index 4a2474a63..07fb6c48b 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/gemini/chunk/manager.py @@ -1,230 +1,239 @@ -import torch -from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable -from collections import deque - -from colossalai.utils import get_current_device -from colossalai.tensor import ColoTensor -from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk - - -class ChunkManager: - """ - A manager class to manipulate the tensors in chunks. - - Args: - chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager. - init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. - """ - - def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None: - - self.device = init_device or get_current_device() - self.size_config: Dict[int, int] = dict() - self.kwargs_config = chunk_configuration - for k, v in self.kwargs_config.items(): - self.size_config[k] = v.pop('chunk_size') - v['init_device'] = self.device - - self.chunk_groups: Dict[str, Deque] = dict() - self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() - self.accessed_chunks: Set[Chunk] = set() - self.accessed_mem: int = 0 - self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} - - def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None: - """Append a tensor to a chunk. - - Args: - tensor: the tensor appended to the chunk - group_type: the data type of the group - config_key: the key of the group's name, usually the size of the dp world - pin_memory: whether the chunk is pinned in the cpu memory - """ - assert tensor not in self.tensor_chunk_map - assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" - assert config_key in self.size_config - - chunk_size = self.size_config[config_key] - chunk_kwargs = self.kwargs_config[config_key] - group_name = "{}_{}".format(group_type, config_key) - chunk_group = self.__get_chunk_group(group_name) - - try: - # append the tensor to the last chunk - chunk_group[-1].append_tensor(tensor) - except (IndexError, ChunkFullError): - # the except statement will be triggered when there is no chunk or - # the last chunk in the chunk group is full - # this will create a new chunk and allocate this chunk to its corresponding process - if chunk_group: - # the chunk group is not empty - # close the last chunk - self.__close_one_chunk(chunk_group[-1]) - - if tensor.numel() > chunk_size: - chunk_size = tensor.numel() - chunk = Chunk( - chunk_size=chunk_size, - process_group=tensor.process_group, - dtype=tensor.dtype, - pin_memory=pin_memory, - **chunk_kwargs, - ) - - chunk_group.append(chunk) - chunk.append_tensor(tensor) - self.__add_memory_usage(chunk.memory_usage) - - self.tensor_chunk_map[tensor] = chunk_group[-1] - - def close_all_groups(self): - """Close all the chunks of all groups. - """ - for group_name in self.chunk_groups: - self.__close_one_chunk(self.chunk_groups[group_name][-1]) - - def access_chunk(self, chunk: Chunk) -> None: - """Make the chunk can be used for calculation. - """ - if chunk in self.accessed_chunks: - return - self.__sub_memroy_usage(chunk.memory_usage) - if chunk.device_type == 'cpu': - chunk.shard_move(get_current_device()) - self.__add_accessed_chunk(chunk) - self.__add_memory_usage(chunk.memory_usage) - - def release_chunk(self, chunk: Chunk) -> None: - """Scatter the chunk in CUDA. - """ - if chunk not in self.accessed_chunks: - return - if chunk.can_release: - self.__sub_memroy_usage(chunk.memory_usage) - self.__sub_accessed_chunk(chunk) - self.__add_memory_usage(chunk.memory_usage) - - def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: - """Move the shard of the chunk to the target device. - """ - if not chunk.can_move or chunk.device_type == device.type: - return - self.__sub_memroy_usage(chunk.memory_usage) - chunk.shard_move(device, force_copy) - self.__add_memory_usage(chunk.memory_usage) - - def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: - """Transit tensor state according to pre-defined state machine. - """ - chunk = self.tensor_chunk_map[tensor] - chunk.tensor_trans_state(tensor, state) - - def reduce_chunk(self, chunk: Chunk) -> bool: - """Reduce or all reduce the chunk. - """ - if not chunk.can_reduce: - return False - self.__sub_memroy_usage(chunk.memory_usage) - chunk.reduce() - self.__sub_accessed_chunk(chunk) - self.__add_memory_usage(chunk.memory_usage) - return True - - def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: - """ - Copy data to the chunk. - - Args: - tensor (torch.Tensor): the tensor used to retrive meta information - data (torch.Tensor): the tensor to be copied to the chunk - """ - chunk = self.tensor_chunk_map[tensor] - chunk.copy_tensor_to_chunk_slice(tensor, data) - - def get_chunk(self, tensor: torch.Tensor) -> Chunk: - """ - Return the chunk owning the tensor. - - Args: - tensor (torch.Tensor): a torch tensor object - """ - return self.tensor_chunk_map[tensor] - - def get_cuda_movable_chunks(self) -> List[Chunk]: - """ - Get all chunks that can be moved. - """ - chunk_list = [] - for chunk in self.accessed_chunks: - if chunk.can_release: - chunk_list.append(chunk) - chunk_list.sort(key=lambda x: x.count_id) - return chunk_list - - def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: - """ - Get all chunks owning the input tensors. - - Args: - tensors (Iterable[torch.Tensor]): the tensors used to look for chunks - """ - chunks = [] - for tensor in tensors: - chunk = self.get_chunk(tensor) - if chunk not in chunks: - chunks.append(chunk) - return tuple(chunks) - - def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: - """Add extern static tensor to chunk manager. - Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them. - They are "static", which means their shape, dtype, device never change. - Thus, their memory usage never changes. - - Args: - tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. - """ - assert tensor not in self.tensor_chunk_map - self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() - - def __repr__(self) -> str: - msg = [ - 'Chunk Manager Information:\n', - 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' - ] - for group_name, group in self.chunk_groups.items(): - msg.append(f'Group {group_name}:\n') - for i, chunk in enumerate(group): - msg.append(f'[{i}] {chunk}\n') - return ''.join(msg) - - def __get_chunk_group(self, group_name: str) -> Deque: - """Register a chunk group. - """ - if group_name not in self.chunk_groups: - self.chunk_groups[group_name] = deque() - return self.chunk_groups[group_name] - - def __close_one_chunk(self, chunk: Chunk): - device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda - self.__sub_memroy_usage(chunk.memory_usage) - chunk.close_chunk(device) - self.__add_memory_usage(chunk.memory_usage) - - def __sub_memroy_usage(self, usage: Dict[str, int]): - for k, v in usage.items(): - self.total_mem[k] -= v - - def __add_memory_usage(self, usage: Dict[str, int]): - for k, v in usage.items(): - self.total_mem[k] += v - - def __add_accessed_chunk(self, chunk: Chunk): - chunk.access_chunk() - self.accessed_chunks.add(chunk) - self.accessed_mem += chunk.chunk_mem - - def __sub_accessed_chunk(self, chunk: Chunk): - chunk.release_chunk() - self.accessed_chunks.remove(chunk) - self.accessed_mem -= chunk.chunk_mem +from collections import deque +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple + +import torch + +from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState +from colossalai.tensor import ColoTensor +from colossalai.utils import get_current_device + + +class ChunkManager: + """ + A manager class to manipulate the tensors in chunks. + + Args: + chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager. + init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. + """ + + def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: + + self.device = init_device or get_current_device() + self.dp_degree_chunk_size_dict: Dict[int, int] = dict() + self.kwargs_config = chunk_configuration + for k, v in self.kwargs_config.items(): + self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size') + v['init_device'] = self.device + + self.chunk_groups: Dict[str, Deque] = dict() + self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() + self.accessed_chunks: Set[Chunk] = set() + self.accessed_mem: int = 0 + self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} + + def register_tensor(self, + tensor: ColoTensor, + group_type: str, + config_key: int, + cpu_offload: bool = False, + pin_memory: bool = False) -> None: + """ + Register a tensor to the chunk manager. + Then, the tensor should be accessed by `get_chunks`. + + Args: + tensor: the tensor appended to the chunk + group_type: the data type of the group. + config_key: the key of the group's name, the size of the dp world + cpu_offload: if True, the chunk will be closed on CPU + pin_memory: whether the chunk is pinned in the cpu memory + """ + assert tensor not in self.tensor_chunk_map + assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" + assert config_key in self.dp_degree_chunk_size_dict + + chunk_size = self.dp_degree_chunk_size_dict[config_key] + chunk_kwargs = self.kwargs_config[config_key] + group_name = "{}_{}".format(group_type, config_key) + chunk_group = self.__get_chunk_group(group_name) + + try: + # append the tensor to the last chunk + chunk_group[-1].append_tensor(tensor) + except (IndexError, ChunkFullError): + # the except statement will be triggered when there is no chunk or + # the last chunk in the chunk group is full + # this will create a new chunk and allocate this chunk to its corresponding process + if chunk_group: + # the chunk group is not empty + # close the last chunk + self.__close_one_chunk(chunk_group[-1]) + + if tensor.numel() > chunk_size: + chunk_size = tensor.numel() + chunk = Chunk( + chunk_size=chunk_size, + process_group=tensor.process_group, + dtype=tensor.dtype, + cpu_shard_init=cpu_offload, + pin_memory=pin_memory, + **chunk_kwargs, + ) + + chunk_group.append(chunk) + chunk.append_tensor(tensor) + self.__add_memory_usage(chunk.memory_usage) + + self.tensor_chunk_map[tensor] = chunk_group[-1] + + def close_all_groups(self): + """Close all the chunks of all groups. + """ + for group_name in self.chunk_groups: + self.__close_one_chunk(self.chunk_groups[group_name][-1]) + + def access_chunk(self, chunk: Chunk) -> None: + """Make the chunk can be used for calculation. + """ + if chunk in self.accessed_chunks: + return + self.__sub_memroy_usage(chunk.memory_usage) + if chunk.device_type == 'cpu': + chunk.shard_move(get_current_device()) + self.__add_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + + def release_chunk(self, chunk: Chunk) -> None: + """Scatter the chunk in CUDA. + """ + if chunk not in self.accessed_chunks: + return + if chunk.can_release: + self.__sub_memroy_usage(chunk.memory_usage) + self.__sub_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + + def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: + """Move the shard of the chunk to the target device. + """ + if not chunk.can_move or chunk.device_type == device.type: + return + self.__sub_memroy_usage(chunk.memory_usage) + chunk.shard_move(device, force_copy) + self.__add_memory_usage(chunk.memory_usage) + + def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: + """Transit tensor state according to pre-defined state machine. + """ + chunk = self.tensor_chunk_map[tensor] + chunk.tensor_trans_state(tensor, state) + + def reduce_chunk(self, chunk: Chunk) -> bool: + """Reduce or all reduce the chunk. + """ + if not chunk.can_reduce: + return False + self.__sub_memroy_usage(chunk.memory_usage) + chunk.reduce() + self.__sub_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + return True + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: + """ + Copy data to the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data (torch.Tensor): the tensor to be copied to the chunk + """ + chunk = self.tensor_chunk_map[tensor] + chunk.copy_tensor_to_chunk_slice(tensor, data) + + def get_chunk(self, tensor: torch.Tensor) -> Chunk: + """ + Return the chunk owning the tensor. + + Args: + tensor (torch.Tensor): a torch tensor object + """ + return self.tensor_chunk_map[tensor] + + def get_cuda_movable_chunks(self) -> List[Chunk]: + """ + Get all chunks that can be moved. + """ + chunk_list = [] + for chunk in self.accessed_chunks: + if chunk.can_release: + chunk_list.append(chunk) + chunk_list.sort(key=lambda x: x.count_id) + return chunk_list + + def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: + """ + Get all chunks owning the input tensors. + + Args: + tensors (Iterable[torch.Tensor]): the tensors used to look for chunks + """ + chunks = [] + for tensor in tensors: + chunk = self.get_chunk(tensor) + if chunk not in chunks: + chunks.append(chunk) + return tuple(chunks) + + def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: + """Add extern static tensor to chunk manager. + Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them. + They are "static", which means their shape, dtype, device never change. + Thus, their memory usage never changes. + + Args: + tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. + """ + assert tensor not in self.tensor_chunk_map + self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() + + def __repr__(self) -> str: + msg = [ + 'Chunk Manager Information:\n', + 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' + ] + for group_name, group in self.chunk_groups.items(): + msg.append(f'Group {group_name}:\n') + for i, chunk in enumerate(group): + msg.append(f'[{i}] {chunk}\n') + return ''.join(msg) + + def __get_chunk_group(self, group_name: str) -> Deque: + """Register a chunk group. + """ + if group_name not in self.chunk_groups: + self.chunk_groups[group_name] = deque() + return self.chunk_groups[group_name] + + def __close_one_chunk(self, chunk: Chunk): + self.__sub_memroy_usage(chunk.memory_usage) + chunk.close_chunk() + self.__add_memory_usage(chunk.memory_usage) + + def __sub_memroy_usage(self, usage: Dict[str, int]): + for k, v in usage.items(): + self.total_mem[k] -= v + + def __add_memory_usage(self, usage: Dict[str, int]): + for k, v in usage.items(): + self.total_mem[k] += v + + def __add_accessed_chunk(self, chunk: Chunk): + chunk.access_chunk() + self.accessed_chunks.add(chunk) + self.accessed_mem += chunk.chunk_mem + + def __sub_accessed_chunk(self, chunk: Chunk): + chunk.release_chunk() + self.accessed_chunks.remove(chunk) + self.accessed_mem -= chunk.chunk_mem diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py index d7b5c7aa8..312d77f18 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/gemini/chunk/search_utils.py @@ -1,9 +1,10 @@ import math -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import torch.nn as nn +from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator from colossalai.tensor import ColoParameter @@ -12,7 +13,8 @@ def in_ddp(param: nn.Parameter) -> bool: def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: - """Filter those parameters whose size is too large from others. + """ + Filter those parameters whose size is too large (more than 3x standard deviations) from others. """ params_size = [p.numel() for p in model.parameters() if in_ddp(p)] params_size_arr = np.array(params_size) @@ -39,11 +41,20 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: return left + acc -def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: - """Clasify each parameter by its size of DP group. +def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]: + """classify_params_by_dp_degree + + Classify the parameters by their dp degree + + Args: + param_order (OrderedParamGenerator): the order of param be visied + + Returns: + Dict[int, List[ColoParameter]]: a dict contains the classification results. + The keys are dp_degrees and the values are parameters. """ params_dict: Dict[int, List[ColoParameter]] = dict() - for param in model.parameters(): + for param in param_order.generate(): assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" if not in_ddp(param): continue @@ -62,24 +73,45 @@ def search_chunk_configuration( search_range_mb: float, search_interval_byte: int, # hidden size is the best value for the interval min_chunk_size_mb: float = 32, - filter_exlarge_params: bool = True) -> Tuple[Dict, int]: + filter_exlarge_params: bool = True, + memstas: Optional[MemStats] = None) -> Tuple[Dict, int]: + """search_chunk_configuration + + Args: + model (nn.Module): torch module + search_range_mb (float): searching range in mega byte. + search_interval_byte (int): searching interval in byte. + filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True. + + Returns: + Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte. + """ + + if memstas is not None: + param_order = memstas.param_order() + else: + # build the param visited order right now + param_order = OrderedParamGenerator() + for p in model.parameters(): + param_order.append(p) + search_range_byte = round(search_range_mb * 1024**2) min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) assert search_range_byte >= 0 - params_dict = clasify_params(model) + params_dict = classify_params_by_dp_degree(param_order) config_dict: Dict[int, Dict] = dict() size_dict: Dict[int, List[int]] = dict() - for key in params_dict: - params_list = params_dict[key] + for dp_degree in params_dict: + params_list = params_dict[dp_degree] size_list = [p.numel() for p in params_list] # let small parameters keep gathered in CUDA all the time total_size = sum(size_list) if total_size < min_chunk_size_byte: - config_dict[key] = dict(chunk_size=total_size, keep_gathered=True) + config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True) else: - size_dict[key] = size_list + size_dict[dp_degree] = size_list if filter_exlarge_params: _filter_exlarge_params(model, size_dict) @@ -100,9 +132,9 @@ def search_chunk_configuration( min_chunk_waste = temp_waste best_chunk_size = chunk_size - for key in params_dict: - if key in config_dict: + for dp_degree in params_dict: + if dp_degree in config_dict: continue - config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False) + config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False) return config_dict, min_chunk_waste diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/gemini/chunk/utils.py index 9d87129db..e9a9f84e7 100644 --- a/colossalai/gemini/chunk/utils.py +++ b/colossalai/gemini/chunk/utils.py @@ -7,6 +7,7 @@ import torch.nn as nn from colossalai.gemini.chunk import ChunkManager from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration +from colossalai.gemini.memory_tracer import MemStats def init_chunk_manager(model: nn.Module, @@ -37,13 +38,13 @@ def init_chunk_manager(model: nn.Module, total_size = sum(params_sizes) / 1024**2 dist.barrier() - begine = time() + begin = time() config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict) dist.barrier() end = time() - span_s = end - begine + span_s = end - begin wasted_size /= 1024**2 if dist.get_rank() == 0: diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index 6d6b7425c..08961b958 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -1,9 +1,13 @@ -import torch import functools -from .memory_tracer.memstats_collector import MemStatsCollectorV2 -from typing import List, Optional, Tuple from time import time +from typing import List, Optional, Tuple + +import torch + from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.gemini.memory_tracer import MemStats + +from .memory_tracer import ChunkMemStatsCollector from .placement_policy import PlacementPolicyFactory @@ -21,13 +25,20 @@ class GeminiManager: If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. Note that 'auto' policy can only work well when no other processes use CUDA during your training. chunk_manager (ChunkManager): A ``ChunkManager`` instance. + memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration. """ - def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None: - assert placement_policy in PlacementPolicyFactory.get_polocy_names() + def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: + + assert placement_policy in PlacementPolicyFactory.get_policy_names() + self.policy_name = placement_policy policy_cls = PlacementPolicyFactory.create(placement_policy) self._chunk_manager = chunk_manager - self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None + + self._premade_memstats_ = memstats is not None + self._memstats = memstats + self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, + self._memstats) if policy_cls.need_mem_stats else None self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 @@ -39,7 +50,20 @@ class GeminiManager: self._warmup = True self._comp_cuda_demand_time = 0 - def pre_iter(self): + def memstats(self): + """memstats + + get the memory statistics during training. + The stats could be collected by a runtime memory tracer, or collected by the GeminiManager. + Note, for the latter, you can not access the memstats before warmup iteration finishes. + """ + if self._premade_memstats_: + return self._memstats + else: + assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup." + return self._mem_stats_collector._memstats + + def pre_iter(self, *args): if self._mem_stats_collector and self._warmup: self._mem_stats_collector.start_collection() @@ -57,7 +81,7 @@ class GeminiManager: self._comp_cuda_demand_time = 0 def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: - """ Adjust the layout of statefuil tensor according to the information provided + """ Adjust the layout of stateful tensors according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. """ # find stateful tensor in state COMPUTE @@ -109,9 +133,9 @@ class GeminiManager: if self._mem_stats_collector: self._mem_stats_collector.sample_overall_data() - def sample_model_data(self): + def record_model_data_volume(self): if self._mem_stats_collector: - self._mem_stats_collector.sample_model_data() + self._mem_stats_collector.record_model_data_volume() @property def chunk_manager(self): diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/gemini/memory_tracer/__init__.py index 21b3e17b9..02c9d5754 100644 --- a/colossalai/gemini/memory_tracer/__init__.py +++ b/colossalai/gemini/memory_tracer/__init__.py @@ -1,5 +1,11 @@ -from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER -from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor -from .memstats_collector import MemStatsCollector +from .param_runtime_order import OrderedParamGenerator # isort:skip +from .memory_stats import MemStats # isort:skip +from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip +from .memstats_collector import MemStatsCollector # isort:skip +from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip +from .static_memstats_collector import StaticMemStatsCollector # isort:skip -__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'] +__all__ = [ + 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', + 'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator' +] diff --git a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py new file mode 100644 index 000000000..1a5b6bf52 --- /dev/null +++ b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py @@ -0,0 +1,36 @@ +from typing import Optional + +from colossalai.gemini.chunk import ChunkManager +from colossalai.gemini.memory_tracer import MemStats +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity + +from .memstats_collector import MemStatsCollector + + +class ChunkMemStatsCollector(MemStatsCollector): + + def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: + """ + + Memory Statistic Collector for Chunks. + + Args: + chunk_manager (ChunkManager): the chunk manager. + memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None. + """ + super().__init__(memstats) + self._chunk_manager = chunk_manager + + # override + def record_model_data_volume(self) -> None: + """ + record model data volumn on cuda and cpu. + """ + if self._start_flag and not self.use_outside_memstats: + cuda_mem = self._chunk_manager.total_mem['cuda'] + self._memstats.record_max_cuda_model_data(cuda_mem) + + @property + def cuda_margin_mem(self) -> float: + return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/gemini/memory_tracer/memory_monitor.py b/colossalai/gemini/memory_tracer/memory_monitor.py index 05d03d278..f8d99dbce 100644 --- a/colossalai/gemini/memory_tracer/memory_monitor.py +++ b/colossalai/gemini/memory_tracer/memory_monitor.py @@ -1,142 +1,147 @@ -from abc import abstractmethod -from concurrent.futures import ThreadPoolExecutor -from time import sleep, time -import json - -import torch - -from colossalai.utils import colo_device_memory_used -from colossalai.utils import get_current_device - - -class MemoryMonitor: - """Base class for all types of memory monitor. - All monitors should have a list called `time_stamps` and a list called `mem_stats`. - """ - - def __init__(self): - self.time_stamps = [] - self.mem_stats = [] - - def __len__(self): - return len(self.mem_stats) - - @abstractmethod - def start(self): - pass - - @abstractmethod - def finish(self): - pass - - def state_dict(self): - return { - "time_stamps": self.time_stamps, - "mem_stats": self.mem_stats, - } - - def save(self, filename): - with open(filename, "w") as f: - json.dump(self.state_dict(), f) - - def clear(self): - self.mem_stats.clear() - self.time_stamps.clear() - - -class AsyncMemoryMonitor(MemoryMonitor): - """ - An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU - at interval of `1/(10**power)` sec. - - The idea comes from Runtime Memory Tracer of PatrickStar - `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_ - - Usage:: - - async_mem_monitor = AsyncMemoryMonitor() - input = torch.randn(2, 20).cuda() - OP1 = torch.nn.Linear(20, 30).cuda() - OP2 = torch.nn.Linear(30, 40).cuda() - - async_mem_monitor.start() - output = OP1(input) - async_mem_monitor.finish() - async_mem_monitor.start() - output = OP2(output) - async_mem_monitor.finish() - async_mem_monitor.save('log.pkl') - - Args: - power (int, optional): the power of time interva. Defaults to 10. - - .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management: - https://arxiv.org/abs/2108.05818 - """ - - def __init__(self, power: int = 10): - super().__init__() - self.keep_measuring = False - - current_device = get_current_device() - - def _set_cuda_device(): - torch.cuda.set_device(current_device) - - self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device) - self.monitor_thread = None - self.interval = 1 / (10**power) - - def set_interval(self, power: int): - self.clear() - self.interval = 1 / (10**power) - - def is_measuring(self): - return self.keep_measuring - - def start(self): - self.keep_measuring = True - self.monitor_thread = self.executor.submit(self._measure_usage) - - def finish(self): - if self.keep_measuring is False: - return 0 - - self.keep_measuring = False - max_usage = self.monitor_thread.result() - - self.monitor_thread = None - self.time_stamps.append(time()) - self.mem_stats.append(max_usage) - return max_usage - - def _measure_usage(self): - max_usage = 0 - while self.keep_measuring: - max_usage = max( - max_usage, - colo_device_memory_used(get_current_device()), - ) - sleep(self.interval) - return max_usage - - -class SyncCudaMemoryMonitor(MemoryMonitor): - """ - A synchronized cuda memory monitor. - It only record the maximum allocated cuda memory from start point to finish point. - """ - - def __init__(self, power: int = 10): - super().__init__() - - def start(self): - torch.cuda.synchronize() - torch.cuda.reset_peak_memory_stats() - - def finish(self): - torch.cuda.synchronize() - self.time_stamps.append(time()) - max_usage = torch.cuda.max_memory_allocated() - self.mem_stats.append(max_usage) - return max_usage +import json +from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor +from time import sleep, time + +import torch + +from colossalai.utils import colo_device_memory_used, get_current_device + + +class MemoryMonitor: + """Base class for all types of memory monitor. + All monitors should have a list called `time_stamps` and a list called `mem_stats`. + """ + + def __init__(self): + self.time_stamps = [] + self.mem_stats = [] + + def __len__(self): + return len(self.mem_stats) + + @abstractmethod + def start(self): + pass + + @abstractmethod + def finish(self): + pass + + def state_dict(self): + return { + "time_stamps": self.time_stamps, + "mem_stats": self.mem_stats, + } + + def save(self, filename): + with open(filename, "w") as f: + json.dump(self.state_dict(), f) + + def clear(self): + self.mem_stats.clear() + self.time_stamps.clear() + + +class AsyncMemoryMonitor(MemoryMonitor): + """ + An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU + at interval of `1/(10**power)` sec. + + The idea comes from Runtime Memory Tracer of PatrickStar + `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_ + + Usage:: + + async_mem_monitor = AsyncMemoryMonitor() + input = torch.randn(2, 20).cuda() + OP1 = torch.nn.Linear(20, 30).cuda() + OP2 = torch.nn.Linear(30, 40).cuda() + + async_mem_monitor.start() + output = OP1(input) + async_mem_monitor.finish() + async_mem_monitor.start() + output = OP2(output) + async_mem_monitor.finish() + async_mem_monitor.save('log.pkl') + + Args: + power (int, optional): the power of time interva. Defaults to 10. + + .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management: + https://arxiv.org/abs/2108.05818 + """ + + def __init__(self, power: int = 10): + super().__init__() + self.keep_measuring = False + + current_device = get_current_device() + + def _set_cuda_device(): + torch.cuda.set_device(current_device) + + self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device) + self.monitor_thread = None + self.interval = 1 / (10**power) + + def set_interval(self, power: int): + self.clear() + self.interval = 1 / (10**power) + + def is_measuring(self): + return self.keep_measuring + + def start(self): + self.keep_measuring = True + self.monitor_thread = self.executor.submit(self._measure_usage) + + def finish(self): + if self.keep_measuring is False: + return 0 + + self.keep_measuring = False + max_usage = self.monitor_thread.result() + + self.monitor_thread = None + self.time_stamps.append(time()) + self.mem_stats.append(max_usage) + return max_usage + + def _measure_usage(self): + max_usage = 0 + while self.keep_measuring: + max_usage = max( + max_usage, + colo_device_memory_used(get_current_device()), + ) + sleep(self.interval) + return max_usage + + +class SyncCudaMemoryMonitor(MemoryMonitor): + """ + A synchronized cuda memory monitor. + It only record the maximum allocated cuda memory from start point to finish point. + """ + + def __init__(self, power: int = 10): + super().__init__() + + def start(self): + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + def finish(self) -> int: + """ + return max gpu memory used since latest `start()`. + + Returns: + int: max GPU memory + """ + torch.cuda.synchronize() + self.time_stamps.append(time()) + max_usage = torch.cuda.max_memory_allocated() + self.mem_stats.append(max_usage) + return max_usage diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py new file mode 100644 index 000000000..84fa00fb9 --- /dev/null +++ b/colossalai/gemini/memory_tracer/memory_stats.py @@ -0,0 +1,127 @@ +from typing import Any, Dict, List, Optional + +import torch + +from colossalai.gemini.memory_tracer import OrderedParamGenerator + + +class MemStats(object): + + def __init__(self) -> None: + """ + Store the non model data statistics used for Gemini and ZeroOptimizer. + """ + # (preop_step, List[param]) + self._step_param_dict = dict() + # (param, List[preop_step]) + self._param_step_dict = dict() + # (preop_step, non_model_data) non model data used during preop_step ~ (preop_step+1) + self._step_nmd_dict = dict() + self._param_runtime_order = OrderedParamGenerator() + + self._preop_step = 0 + + self._prev_overall_cuda = -1 + self._max_overall_cuda = 0 + self._prev_md_cuda = -1 + + # old version + self._model_data_cuda_list = [] + self._model_data_cpu_list = [] + + self._overall_cuda_list = [] + self._overall_cpu_list = [] + + self._non_model_data_cuda_list = [] + self._non_model_data_cpu_list = [] + + def calc_max_cuda_non_model_data(self): + if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1: + max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda + self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data + # compatibility of the old version. + self._non_model_data_cuda_list.append(max_cuda_non_model_data) + + def record_max_cuda_model_data(self, val): + self._prev_md_cuda = val + + def record_max_cuda_overall_data(self, val): + self._prev_overall_cuda = val + self._max_overall_cuda = max(self._max_overall_cuda, val) + + @property + def max_overall_cuda(self): + return self._max_overall_cuda + + def increase_preop_step(self, param_list: List[torch.nn.Parameter]): + """ + the time step is increased. param list is used between current and the next + time step. + + Args: + param_list (List[torch.nn.Parameter]): a list of torch paramters. + """ + for p in param_list: + if p not in self._param_step_dict: + self._param_step_dict[p] = [self._preop_step] + else: + self._param_step_dict[p].append(self._preop_step) + self._param_runtime_order.append(p) + self._step_param_dict[self._preop_step] = param_list + self._preop_step += 1 + + def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]: + """param_used_step + get the timestep list using the param + + Args: + param (torch.nn.Parameter): a torch param + + Returns: + Optional[List[int]]: a list of int indicates the time step of preop hook. + """ + if param not in self._param_step_dict: + return None + else: + return self._param_step_dict[param] + + def param_order(self): + if self._param_runtime_order.is_empty(): + raise RuntimeError + else: + return self._param_runtime_order + + def non_model_data_list(self, device_type: str) -> List[int]: + if device_type == 'cuda': + return self._non_model_data_cuda_list + elif device_type == 'cpu': + return self._non_model_data_cpu_list + else: + raise TypeError + + def max_non_model_data(self, device_type: str) -> float: + if device_type == 'cuda': + return max(self._non_model_data_cuda_list) + elif device_type == 'cpu': + return max(self._non_model_data_cpu_list) + else: + raise TypeError + + def clear(self): + self._model_data_cuda_list = [] + self._overall_cuda_list = [] + + self._model_data_cpu_list = [] + self._overall_cpu_list = [] + + self._non_model_data_cpu_list = [] + self._non_model_data_cuda_list = [] + + self._param_runtime_order.clear() + self._step_param_dict.clear() + self._param_step_dict.clear() + self._step_nmd_dict.clear() + self._preop_step = 0 + + self._prev_overall_cuda = -1 + self._prev_md_cuda = -1 diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index 4366956fe..d939da6eb 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -1,18 +1,19 @@ -from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor -from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity -from colossalai.utils import get_current_device -from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.gemini.chunk import ChunkManager +import time +from typing import List, Optional import torch -import time -from typing import List + +from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.utils.memory import colo_device_memory_used + +from .memory_stats import MemStats class MemStatsCollector: """ A Memory statistic collector. - It works in two phases. + It works in two phases. Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU. The first iteration of DNN training. Phase 2. Runtime Phase: use the read-only collected stats @@ -21,48 +22,22 @@ class MemStatsCollector: It has a Sampling counter which is reset after DNN training iteration. """ - def __init__(self) -> None: + def __init__(self, memstats: Optional[MemStats] = None) -> None: self._mem_monitor = SyncCudaMemoryMonitor() - self._model_data_cuda_list = [] - self._overall_cuda_list = [] - - self._model_data_cpu_list = [] - self._overall_cpu_list = [] - - self._non_model_data_cuda_list = [] - self._non_model_data_cpu_list = [] self._sampling_time = [] self._start_flag = False self._step_idx = 0 self._step_total = 0 - - def overall_mem_stats(self, device_type: str) -> List[int]: - if device_type == 'cuda': - return self._overall_cuda_list - elif device_type == 'cpu': - return self._overall_cpu_list + if memstats is not None: + self.use_outside_memstats = True + self._memstats = memstats else: - raise TypeError - - def model_data_list(self, device_type: str) -> List[int]: - if device_type == 'cuda': - return self._model_data_cuda_list - elif device_type == 'cpu': - return self._model_data_cpu_list - else: - raise TypeError - - def non_model_data_list(self, device_type: str) -> List[int]: - if device_type == 'cuda': - return self._non_model_data_cuda_list - elif device_type == 'cpu': - return self._non_model_data_cpu_list - else: - raise TypeError + self.use_outside_memstats = False + self._memstats = MemStats() def next_period_non_model_data_usage(self, device_type: str) -> int: - """Get max non model data memory usage of current sampling period + """Maximum non model data memory usage during the next Op run Args: device_type (str): device type, can be 'cpu' or 'cuda'. @@ -72,7 +47,10 @@ class MemStatsCollector: """ assert not self._start_flag, 'Cannot get mem stats info during collection phase.' assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' - next_non_model_data = self.non_model_data_list(device_type)[self._step_idx] + assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \ + f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\ + f"step total {self._step_total}" + next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] self._step_idx = (self._step_idx + 1) % self._step_total return next_non_model_data @@ -86,67 +64,37 @@ class MemStatsCollector: def finish_collection(self): self.sample_overall_data() - self._step_total = len(self._sampling_time) + # self._step_total = len(self._sampling_time) + self._step_total = len(self._memstats.non_model_data_list('cuda')) self._start_flag = False - self._mem_monitor.finish() + print(f'finish_collection {self._step_total}') - def sample_model_data(self) -> None: - """Sampling model data statistics. + # deprecated + def record_model_data_volume(self) -> None: """ - if self._start_flag: + Sampling model data statistics. + """ + if self._start_flag and not self.use_outside_memstats: + # The following code work for ZeroInitContext, which is deprecated in v0.1.12 cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] - cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] - self._model_data_cuda_list.append(cuda_mem) - self._model_data_cpu_list.append(cpu_mem) + self._memstats.record_max_cuda_model_data(cuda_mem) def sample_overall_data(self) -> None: - """Sampling non model data statistics. """ - if self._start_flag: - # overall data recording is after model data recording - if len(self._model_data_cuda_list) == 0: - return + Sampling overall and non model data cuda memory statistics. + """ + if self._start_flag and not self.use_outside_memstats: + cuda_overall = self._mem_monitor.finish() + self._memstats.record_max_cuda_overall_data(cuda_overall) + self._memstats.calc_max_cuda_non_model_data() - self._overall_cuda_list.append(self._mem_monitor.finish()) - self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu'))) - - assert len(self._model_data_cuda_list) == len(self._overall_cuda_list) - - self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1]) - self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1]) - self._sampling_time.append(time.time()) self._mem_monitor.start() + if self._start_flag: + self._sampling_time.append(time.time()) + def clear(self) -> None: - self._model_data_cuda_list = [] - self._overall_cuda_list = [] - - self._model_data_cpu_list = [] - self._overall_cpu_list = [] - - self._non_model_data_cpu_list = [] - self._non_model_data_cuda_list = [] - + self._memstats.clear() self._start_flag = False self._step_idx = 0 self._step_total = 0 - - -class MemStatsCollectorV2(MemStatsCollector): - - def __init__(self, chunk_manager: ChunkManager) -> None: - super().__init__() - self._chunk_manager = chunk_manager - - def sample_model_data(self) -> None: - """Sampling model data statistics. - """ - if self._start_flag: - cuda_mem = self._chunk_manager.total_mem['cuda'] - cpu_mem = self._chunk_manager.total_mem['cpu'] - self._model_data_cuda_list.append(cuda_mem) - self._model_data_cpu_list.append(cpu_mem) - - @property - def cuda_margin_mem(self) -> float: - return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda')) diff --git a/colossalai/gemini/memory_tracer/param_runtime_order.py b/colossalai/gemini/memory_tracer/param_runtime_order.py new file mode 100644 index 000000000..638c0533c --- /dev/null +++ b/colossalai/gemini/memory_tracer/param_runtime_order.py @@ -0,0 +1,42 @@ +from abc import ABC + +import torch + + +class ParamGenerator(ABC): + + def append(self, param: torch.nn.Parameter): + pass + + def generate(self): + pass + + def clear(self): + pass + + +class OrderedParamGenerator(ParamGenerator): + """OrderedParamGenerator + + Contain the order of parameters visited during runtime. + """ + + def __init__(self) -> None: + self.param_visited_order = [] + + def append(self, param: torch.nn.Parameter): + self.param_visited_order.append(param) + + def generate(self): + visited_set = set() + for p in self.param_visited_order: + if p not in visited_set: + yield p + visited_set.add(p) + del visited_set + + def is_empty(self): + return len(self.param_visited_order) == 0 + + def clear(self): + self.param_visited_order = [] diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py new file mode 100644 index 000000000..a643751da --- /dev/null +++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py @@ -0,0 +1,99 @@ +import torch.nn + +from colossalai.gemini.memory_tracer import MemStats +from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook +from colossalai.nn.parallel.data_parallel import _cast_float +from colossalai.tensor.param_op_hook import ColoParamOpHookManager + +__all__ = ['RuntimeMemTracer'] + + +class RuntimeMemTracer(): + """RuntimeMemTracer for the module training using ColoParameter. + + Trace non-model memory usage during fwd+bwd process. + It is obtained by using a tensor with the same shape as the training process as the inputs + and running an single fwd+bwd to trace the statistics. + + NOTE() + 1. The premise to use this tracer is that the target DNN execute the same operations at each iterations, + 2. Module buffers are viewed as non-model data. + """ + + def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half): + super().__init__() + self.module = module + self.dtype = dtype + self._gradstat = GradMemStats() + self._memstats = MemStats() + self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat) + self.grad_hook = GradMemTracerHook(self._gradstat) + self.cpu_param_data_dict = {} + + for p in module.parameters(): + p.data = p.data.to(dtype) + + self._cast_buffers_to_cuda_dtype() + + def parameters_in_runtime_order(self): + return self._memstats._param_runtime_order.generate() + + def memstats(self): + return self._memstats + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _backup_params(self): + """ + The function is called before forward. Backup model params on cpu. + """ + for p in self.module.parameters(): + self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu") + self.cpu_param_data_dict[p].copy_(p.data) + + def _restore_params(self): + """ + This function is called after backward. Restore model params. + """ + for p in self.module.parameters(): + p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad) + p.data.copy_(self.cpu_param_data_dict[p]) + self.cpu_param_data_dict.clear() + + def _pre_forward(self): + self._clear_cuda_mem_info() + self._backup_params() + self.grad_hook.register_grad_hook(self.module) + self.param_op_hook.mem_monitor.start() + + def forward(self, *args, **kwargs): + args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype) + self.module.zero_grad(set_to_none=True) + self._pre_forward() + with ColoParamOpHookManager.use_hooks(self.param_op_hook): + outputs = self.module(*args, **kwargs) + return outputs + + def backward(self, loss): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + loss.backward() + self._post_backward() + + def _post_backward(self): + cuda_volume = self.param_op_hook.mem_monitor.finish() + self._memstats.record_max_cuda_overall_data(cuda_volume) + # calc the last Op non model data + self._memstats.calc_max_cuda_non_model_data() + self.grad_hook.remove_grad_hook() + self._restore_params() + + def _clear_cuda_mem_info(self): + self._memstats.clear() + self._gradstat.clear() + + def _cast_buffers_to_cuda_dtype(self): + for buffer in self.module.buffers(): + buffer.data = buffer.cuda() + if torch.is_floating_point(buffer): + buffer.data = buffer.data.to(self.dtype) diff --git a/colossalai/gemini/memory_tracer/static_memstats_collector.py b/colossalai/gemini/memory_tracer/static_memstats_collector.py new file mode 100644 index 000000000..3209881e1 --- /dev/null +++ b/colossalai/gemini/memory_tracer/static_memstats_collector.py @@ -0,0 +1,105 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch.fx import symbolic_trace + +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta +from colossalai.gemini.chunk import ChunkManager + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor + +from .chunk_memstats_collector import ChunkMemStatsCollector + + +class ModuleInfos: + + def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str, + parent_module: torch.nn.Module): + self.module = module + self.module_name = module_name + self.module_full_name = module_full_name + self.parent_module = parent_module + + +class StaticMemStatsCollector(ChunkMemStatsCollector): + """ + A Static Memory statistic collector. + """ + + def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None: + super().__init__(chunk_manager) + self.module = module + self.module_info_list = [] + + def init_mem_stats(self, *inputs): + + self.register_opnodes_recursively(self.module) + self.refactor_module() + + self.module = self.module.cpu() + self.module.train() + + data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs] + gm = symbolic_trace(self.module) + interp = MetaInfoProp(gm) + interp.propagate(*data) + + total_mem = 0 + for inp in inputs: + total_mem += inp.numel() * inp.element_size() + last_node = None + module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list] + for node in gm.graph.nodes: + total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node) + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem) + last_node = node + self._non_model_data_cuda_list.append(total_mem) + self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:] + + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = last_node.meta["fwd_mem_out"] + for node in gm.graph.nodes.__reversed__(): + cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node) + cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] + if node.op == "call_module": + if node.name.endswith("_0") and node.name[:-2] in module_name_list: + self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd) + total_mem = total_mem - cur_module_mem_fwd + cur_module_mem_fwd = 0 + cur_module_mem_bwd = 0 + grad_module_out = node.meta["bwd_mem_out"] + + self._step_total = len(self._non_model_data_cuda_list) + self.recover_module() + + def refactor_module(self): + for modInfo in self.module_info_list: + temp_node = nn.Sequential(nn.ReLU(), modInfo.module) + modInfo.parent_module.__setattr__(modInfo.module_name, temp_node) + + def recover_module(self): + for modInfo in self.module_info_list: + modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module) + + def register_opnodes_recursively(self, + module: torch.nn.Module, + name: str = "", + full_name: str = "", + parent_module: Optional[torch.nn.Module] = None): + + assert isinstance(module, torch.nn.Module) + + for child_name, child in module.named_children(): + self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module) + + # Early return on modules with no parameters. + if len(list(module.parameters(recurse=False))) == 0: + return + + self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module)) diff --git a/colossalai/gemini/memory_tracer/model_data_memtracer.py b/colossalai/gemini/memory_tracer/utils.py similarity index 51% rename from colossalai/gemini/memory_tracer/model_data_memtracer.py rename to colossalai/gemini/memory_tracer/utils.py index 98228892d..6962c0581 100644 --- a/colossalai/gemini/memory_tracer/model_data_memtracer.py +++ b/colossalai/gemini/memory_tracer/utils.py @@ -1,7 +1,6 @@ -from colossalai.context.singleton_meta import SingletonMeta +from typing import Optional, Tuple + import torch -from typing import Tuple, Optional -from colossalai.logging import DistributedLogger def colo_model_optimizer_usage(optim) -> Tuple[int, int]: @@ -20,7 +19,7 @@ def colo_model_optimizer_usage(optim) -> Tuple[int, int]: def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]: - """ + """ Trace the model memory usage. Args: model (torch.nn.Module): a torch model @@ -58,52 +57,3 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]: cpu_mem_usage += t_cpu return cuda_mem_usage, cpu_mem_usage - - -class ModelDataTracer(metaclass=SingletonMeta): - """ - A tracer singleton to trace model data usage during runtime. - You have to register a model on the singleton first. - """ - - def __init__(self) -> None: - self._logger = DistributedLogger("ModelDataTracer") - self._model = None - self._opitimizer = None - - def _get_mem_usage(self) -> Tuple[int, int]: - """ - get the memory usage of the model registered. - Returns: - Tuple[int, int]: cuda, cpu mem usage - """ - cuda_use_opt, cpu_use_opt = colo_model_optimizer_usage(self._opitimizer) - cuda_use_model, cpu_use_model = colo_model_mem_usage(self._model) - return cuda_use_opt + cuda_use_model, cpu_use_opt + cpu_use_model - - def register_model(self, model) -> None: - if self._model is not None: - self._logger.warning("ModelDataTracer has already registered a model") - self._model = model - - def register_optimizer(self, optimizer) -> None: - if self._opitimizer is not None: - self._logger.warning("ModelDataTracer has already registered an optimizer") - self._opitimizer = optimizer - - @property - def cpu_usage(self): - _, cpu_usage = self._get_mem_usage() - return cpu_usage - - @property - def cuda_usage(self): - cuda_usage, _ = self._get_mem_usage() - return cuda_usage - - @property - def both_mem_usage(self): - return self._get_mem_usage() - - -GLOBAL_MODEL_DATA_TRACER = ModelDataTracer() diff --git a/colossalai/gemini/ophooks/__init__.py b/colossalai/gemini/ophooks/__init__.py index 9e81ba56d..b65726166 100644 --- a/colossalai/gemini/ophooks/__init__.py +++ b/colossalai/gemini/ophooks/__init__.py @@ -1,4 +1,3 @@ -from .utils import register_ophooks_recursively, BaseOpHook -from ._memtracer_ophook import MemTracerOpHook +from .utils import BaseOpHook, register_ophooks_recursively -__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"] +__all__ = ["BaseOpHook", "register_ophooks_recursively"] diff --git a/colossalai/gemini/ophooks/_memtracer_ophook.py b/colossalai/gemini/ophooks/_memtracer_ophook.py deleted file mode 100644 index 71831f1aa..000000000 --- a/colossalai/gemini/ophooks/_memtracer_ophook.py +++ /dev/null @@ -1,117 +0,0 @@ -import json -import pickle -from pathlib import Path -from colossalai.context.parallel_mode import ParallelMode -import torch -from colossalai.gemini.ophooks import BaseOpHook -from colossalai.registry import OPHOOKS -from colossalai.logging import get_dist_logger -from colossalai.core import global_context as gpc -from typing import Union -import math - - -@OPHOOKS.register_module -class MemTracerOpHook(BaseOpHook): - """ - Collect GPU memory usage information - - Args: - warmup (int): This parameter indicates how many iterations to truncate before profiling, defaults to 50. - refreshrate (int): This parameter decides the frequency of write file, defaults to 10. - data_prefix (string): The prefix of the stats data file, defaults to "memstats". - """ - - def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"): - from colossalai.gemini.memory_tracer import AsyncMemoryMonitor - super().__init__() - self.async_mem_monitor = AsyncMemoryMonitor() - self._curiter = 0 - self._logger = get_dist_logger() - self._count = 0 - self._warmup = warmup - self._refreshrate = refreshrate - self._data_prefix = data_prefix - # in distributed environment - if gpc.is_initialized(ParallelMode.GLOBAL): - self._rank = gpc.get_global_rank() - else: - self._rank = 0 - - def _isvalid(self, module) -> bool: - assert isinstance(module, torch.nn.Module) - return module.training - - def _resample(self): - # calculate the average iteration time - total_time = (self.async_mem_monitor.time_stamps[-1] - self.async_mem_monitor.time_stamps[0]) - avg_it_time = total_time / self.warmup - self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s") - # adjust the sampling power - power: int = round(-math.log(avg_it_time, 10)) + 1 - self._logger.debug(f"the power is {power}") - self.async_mem_monitor.set_interval(power) - - @property - def refreshrate(self) -> int: - return self._refreshrate - - @property - def warmup(self) -> int: - return self._warmup - - @property - def curiter(self) -> int: - return self._curiter - - @property - def valid_iter(self) -> int: - return self.curiter - self.warmup - - def pre_fwd_exec(self, module: torch.nn.Module, *args): - if self._isvalid(module): - self.async_mem_monitor.finish() - self.async_mem_monitor.start() - - def post_fwd_exec(self, module: torch.nn.Module, *args): - if self._isvalid(module): - self.async_mem_monitor.finish() - - def pre_bwd_exec(self, module: torch.nn.Module, input, output): - if self._isvalid(module): - self.async_mem_monitor.finish() - self.async_mem_monitor.start() - - def post_bwd_exec(self, module: torch.nn.Module, input): - if self._isvalid(module): - self.async_mem_monitor.finish() - - def pre_iter(self): - pass - - def post_iter(self): - self.async_mem_monitor.finish() - # in the warmup stage - if self.curiter < self.warmup: - pass - # adjust the sampling rate - elif self.curiter == self.warmup: - # use adaptive sample rate - self._resample() - # record data to log file - else: - # every `refreshrate` times, refresh the file - if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0: - # output file info - self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl") - home_dir = Path.home() - with open(home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f: - pickle.dump(self.async_mem_monitor.state_dict, f) - self._count += 1 - self._logger.debug(f"data file has been refreshed {self._count} times") - # finish a iteration - self._curiter += 1 - - def save_results(self, data_file: Union[str, Path]): - with open(data_file, "w") as f: - f.write(json.dumps(self.async_mem_monitor.state_dict)) diff --git a/colossalai/gemini/ophooks/_shard_grad_ophook.py b/colossalai/gemini/ophooks/_shard_grad_ophook.py index 582f95802..5115ff74d 100644 --- a/colossalai/gemini/ophooks/_shard_grad_ophook.py +++ b/colossalai/gemini/ophooks/_shard_grad_ophook.py @@ -1,11 +1,12 @@ import torch + from colossalai.registry import OPHOOKS from . import BaseOpHook @OPHOOKS.register_module -class ShardGradHook(BaseOpHook): +class ShardGradMemTracerHook(BaseOpHook): """ A hook to process sharded param before and afther FWD and BWD operator executing. """ diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py new file mode 100644 index 000000000..6d0df4e61 --- /dev/null +++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py @@ -0,0 +1,145 @@ +from contextlib import contextmanager +from enum import Enum +from functools import partial +from typing import List + +import torch + +from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor +from colossalai.gemini.tensor_utils import alloc_storage, free_storage +from colossalai.tensor.param_op_hook import ColoParamOpHook + + +class TrainingPhase(Enum): + FORWARD = 0 + BACKWARD = 1 + + +class GradMemStats(): + + def __init__(self) -> None: + self.unreleased_grad_flag = {} + self.unreleased_grad_volume = 0 + + def clear(self): + self.unreleased_grad_flag.clear() + self.unreleased_grad_volume = 0 + + +class GradMemTracerHook(): + + def __init__(self, grad_stats: GradMemStats): + self.grad_hook_list = [] + self._grad_stats = grad_stats + + def grad_handle(self, p, grad): + assert self._grad_stats.unreleased_grad_flag[p] + free_storage(grad) + self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size() + self._grad_stats.unreleased_grad_flag[p] = False + + def register_grad_hook(self, module: torch.nn.Module): + for p in module.parameters(): + if p.requires_grad: + self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p))) + self._grad_stats.unreleased_grad_flag[p] = False + + def remove_grad_hook(self): + for hook in self.grad_hook_list: + hook.remove() + + +class ParamMemTracerHook(ColoParamOpHook): + + def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None: + super().__init__() + self._training_phase = TrainingPhase.FORWARD + self._memstats = memstats + self._grad_stats = gradstats + self.mem_monitor = SyncCudaMemoryMonitor() + + def _free_cuda_params(self, params): + for p in params: + if p.data.device.type == "cpu": + raise NotImplementedError("Only free cuda memory") + free_storage(p.data) + + def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]): + """ + move params to cuda + + Args: + params (List[torch.nn.Parameter]): target params + + Raises: + NotImplementedError: raise error when param has cpu grad + """ + for p in params: + cur_dev = p.data.device.type + if cur_dev == "cpu": + if p.grad is not None and p.grad.device.type == "cpu": + raise NotImplementedError("Only run in forward propagation") + p.data = torch.empty(p.data.shape, + device="cuda", + dtype=p.data.dtype, + requires_grad=p.data.requires_grad) + elif cur_dev == "cuda": + alloc_storage(p.data) + + def record_model_data_volume(self, params): + """ + get cuda model data used by params + """ + data_volume = self._grad_stats.unreleased_grad_volume + for p in params: + cur_model_data_volume = p.data.numel() * p.data.element_size() + data_volume += cur_model_data_volume + if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad: + # add param.grad, actually param.grad is None in this time + data_volume += cur_model_data_volume + if not self._grad_stats.unreleased_grad_flag[p]: + self._grad_stats.unreleased_grad_volume += cur_model_data_volume + self._grad_stats.unreleased_grad_flag[p] = True + # record max non model data used for this Op + self._memstats.record_max_cuda_model_data(data_volume) + + def pre_op(self, params): + max_cuda_used_pre_op = self.mem_monitor.finish() + # record max cuda overall data for prev OP. + self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op) + # record max cuda non model data for prev OP. + self._memstats.calc_max_cuda_non_model_data() + + self._allocate_params_on_cuda(params) + # record max cuda model data for current OP + self.record_model_data_volume(params) + + self.mem_monitor.start() + self._memstats.increase_preop_step(params) + + def post_op(self, params): + self._free_cuda_params(params) + + def pre_forward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_forward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + def pre_backward(self, params: List[torch.Tensor]) -> None: + self.pre_op(params) + + def post_backward(self, params: List[torch.Tensor]) -> None: + self.post_op(params) + + @contextmanager + def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): + old_training_phase = self._training_phase + try: + self._training_phase = training_phase + yield + finally: + self._training_phase = old_training_phase + + switch_to_backward = switch_training_phase + switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD) diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py index ab1988b11..fed1cc298 100644 --- a/colossalai/gemini/placement_policy.py +++ b/colossalai/gemini/placement_policy.py @@ -1,22 +1,24 @@ +import functools from abc import ABC, abstractmethod from time import time -from typing import List, Optional, Tuple, Dict +from typing import Dict, List, Optional, Tuple, Type + import torch + +from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.gemini.memory_tracer import ChunkMemStatsCollector from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity -from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2 -from typing import Type -import functools -from colossalai.gemini.chunk import Chunk, ChunkManager - class PlacementPolicy(ABC): need_mem_stats: bool = False - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: self.chunk_manager = chunk_manager - self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector + self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector @abstractmethod def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: @@ -29,7 +31,9 @@ class PlacementPolicy(ABC): class CPUPlacementPolicy(PlacementPolicy): - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: @@ -44,7 +48,9 @@ class CPUPlacementPolicy(PlacementPolicy): class CUDAPlacementPolicy(PlacementPolicy): - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) @@ -65,7 +71,9 @@ class AutoPlacementPolicy(PlacementPolicy): _warmup_non_model_data_ratio: float = 0.8 _steady_cuda_cap_ratio: float = 0.9 - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) def evict_tensors(self, @@ -154,7 +162,9 @@ class ConstPlacementPolicy(PlacementPolicy): need_mem_stats: bool = False _accessed_memory_boundary = 512 * 1024**2 - def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + def __init__(self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) def evict_tensors(self, @@ -226,7 +236,7 @@ class PlacementPolicyFactory: return PlacementPolicyFactory.policies[policy_name] @staticmethod - def get_polocy_names(): + def get_policy_names(): return tuple(PlacementPolicyFactory.policies.keys()) @staticmethod diff --git a/colossalai/gemini/tensor_utils.py b/colossalai/gemini/tensor_utils.py index f2d69046e..bcc159f99 100644 --- a/colossalai/gemini/tensor_utils.py +++ b/colossalai/gemini/tensor_utils.py @@ -3,6 +3,20 @@ from colossalai.gemini.stateful_tensor import StatefulTensor from typing import Union, Tuple +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]: if isinstance(tensor, StatefulTensor): t = tensor.payload diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py index 24f8b60dd..e3575ea12 100644 --- a/colossalai/global_variables.py +++ b/colossalai/global_variables.py @@ -22,7 +22,9 @@ class TensorParallelEnv(object): depth_3d: int = None, input_group_3d=None, weight_group_3d=None, - output_group_3d=None): + output_group_3d=None, + input_x_weight_group_3d=None, + output_x_weight_group_3d=None): self.mode = mode self.vocab_parallel = vocab_parallel self.parallel_input_1d = parallel_input_1d @@ -33,6 +35,8 @@ class TensorParallelEnv(object): self.input_group_3d = input_group_3d self.weight_group_3d = weight_group_3d self.output_group_3d = output_group_3d + self.input_x_weight_group_3d = input_x_weight_group_3d + self.output_x_weight_group_3d = output_x_weight_group_3d def save(self): return dict(mode=self.mode, @@ -44,7 +48,9 @@ class TensorParallelEnv(object): depth_3d=self.depth_3d, input_group_3d=self.input_group_3d, weight_group_3d=self.weight_group_3d, - output_group_3d=self.output_group_3d) + output_group_3d=self.output_group_3d, + input_x_weight_group_3d=self.input_x_weight_group_3d, + output_x_weight_group_3d=self.output_x_weight_group_3d) tensor_parallel_env = TensorParallelEnv() diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 42c95729a..8933fc0a3 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,3 +1,7 @@ -from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention +from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention -__all__ = ["LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"] +__all__ = [ + "LayerNorm", + "FusedScaleMaskSoftmax", + "MultiHeadAttention", +] diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py index a35158b72..8f857ff5d 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/colossalai/kernel/cuda_native/__init__.py @@ -1,3 +1,3 @@ from .layer_norm import MixedFusedLayerNorm as LayerNorm -from .scaled_softmax import FusedScaleMaskSoftmax from .multihead_attention import MultiHeadAttention +from .scaled_softmax import FusedScaleMaskSoftmax diff --git a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp b/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp index a687adc7b..94f132521 100644 --- a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp +++ b/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp @@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, - const int bias_correction, - const float weight_decay); + const int bias_correction, const float weight_decay, + const float div_scale); void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, @@ -46,4 +46,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Computes and apply update for LAMB optimizer"); m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, "Computes L2 norm for a list of contiguous tensors"); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu index 891f23e4e..afd34bb96 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu @@ -28,7 +28,7 @@ struct AdamFunctor { int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, const float beta1, const float beta2, const float beta1_correction, const float beta2_correction, const float epsilon, const float lr, - adamMode_t mode, const float decay) { + adamMode_t mode, const float decay, const float div_scale) { // I'd like this kernel to propagate infs/nans. // if(*noop_gmem == 1) // return; @@ -79,6 +79,8 @@ struct AdamFunctor { } #pragma unroll for (int ii = 0; ii < ILP; ii++) { + if (div_scale > 0) r_g[ii] /= div_scale; + if (mode == ADAM_MODE_0) { // L2 r_g[ii] = r_g[ii] + (decay * r_p[ii]); r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; @@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, - const int bias_correction, - const float weight_decay) { + const int bias_correction, const float weight_decay, + const float div_scale) { using namespace at; // Handle bias correction mode @@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamFunctor(), beta1, beta2, bias_correction1, bias_correction2, epsilon, - lr, (adamMode_t)mode, weight_decay);) + lr, (adamMode_t)mode, weight_decay, div_scale);) AT_CUDA_CHECK(cudaGetLastError()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp index b02556f79..d08f3dbc7 100644 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp @@ -2,8 +2,14 @@ #include #include +#include +#if TORCH_VERSION_MAJOR > 1 || \ + (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) +#include +#else #include +#endif #include #include "context.h" diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h index 70b3419d8..6505eb31f 100644 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h @@ -4,8 +4,15 @@ #include #include #include +#include +#if TORCH_VERSION_MAJOR > 1 || \ + (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) +#include +#else #include +#endif + #include #include @@ -157,4 +164,4 @@ class MultiHeadAttention { c10::intrusive_ptr pg; int pg_size; -}; \ No newline at end of file +}; diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py new file mode 100644 index 000000000..7bd646d39 --- /dev/null +++ b/colossalai/kernel/cuda_native/flash_attention.py @@ -0,0 +1,525 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton) +""" + +import math +import os +import subprocess + +import torch + + +def triton_cuda_check(): + cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda") + cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip() + cuda_version = cuda_version.split('release ')[1] + cuda_version = cuda_version.split(',')[0] + cuda_version = cuda_version.split('.') + if len(cuda_version) == 2 and \ + (int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \ + int(cuda_version[0]) > 11: + return True + return False + + +try: + import triton + import triton.language as tl + if triton_cuda_check(): + HAS_TRITON = True + else: + print("triton requires cuda >= 11.4") + HAS_TRITON = False +except ImportError: + print('please install triton from https://github.com/openai/triton') + HAS_TRITON = False +try: + from flash_attn.flash_attention import FlashAttention + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_func, + flash_attn_unpadded_kvpacked_func, + flash_attn_unpadded_qkvpacked_func, + ) + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + print('please install flash_attn from https://github.com/HazyResearch/flash-attention') + +try: + from xformers.ops.fmha import memory_efficient_attention + HAS_MEM_EFF_ATTN = True +except ImportError: + HAS_MEM_EFF_ATTN = False + print('please install xformers from https://github.com/facebookresearch/xformers') + +if HAS_TRITON: + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + sm_scale, + TMP, + L, + M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + @triton.jit + def _bwd_preprocess( + Out, + DO, + L, + NewDO, + Delta, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + @triton.jit + def _bwd_kernel( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + M, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + num_block, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, k, trans_b=True) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(tl.float16), do, trans_a=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(ds.to(tl.float16), q, trans_a=True) + # # compute dq + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + # # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + class _TritonFlashAttention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + tmp, + L, + m, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, + num_warps=num_warps, + num_stages=1, + ) + ctx.save_for_backward(q, k, v, o, L, m) + ctx.BLOCK = BLOCK + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)]( + o, + do, + l, + do_scaled, + delta, + BLOCK_M=ctx.BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + + # NOTE: kernel currently buggy for other values of `num_warps` + num_warps = 8 + _bwd_kernel[(ctx.grid[1],)]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dq, + dk, + dv, + l, + m, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, + BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + ) + return dq, dk, dv, None + + def triton_flash_attention(q, k, v, sm_scale): + """ + Arguments: + q: (batch, nheads, seq, headdim) + k: (batch, nheads, seq, headdim) + v: (batch, nheads, seq, headdim) + sm_scale: float. The scaling of QK^T before applying softmax. + Return: + out: (batch, nheads, seq, headdim) + """ + if HAS_TRITON: + return _TritonFlashAttention.apply(q, k, v, sm_scale) + else: + raise RuntimeError("Triton kernel requires CUDA 11.4+!") + + +if HAS_FLASH_ATTN: + + from einops import rearrange + + class MaskedFlashAttention(torch.nn.Module): + + def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None: + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size), + attention_dropout=attention_dropout) + + def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False): + if attention_mask.dtype is not torch.bool: + attention_mask = attention_mask.bool() + qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads) + context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal) + context = rearrange(context, 'b s h d -> b s (h d)') + return context + + def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False): + """ + Arguments: + qkv: (batch * seqlen, 3, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + dropout_p: float. + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + """ + max_s = seq_len + cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device) + out = flash_attn_unpadded_qkvpacked_func(qkv, + cu_seqlens, + max_s, + dropout_p, + softmax_scale=sm_scale, + causal=causal) + return out + + def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): + """ + Arguments: + q: (batch * q_seqlen, nheads, headdim) + kv: (batch * kv_seqlen, 2, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + dropout_p: float. + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + """ + cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, + step=kv_seqlen, + dtype=torch.int32, + device=kv.device) + out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p, + sm_scale, causal) + return out + + def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): + """ + Arguments: + q: (batch * q_seqlen, nheads, headdim) + k: (batch * kv_seqlen, nheads, headdim) + v: (batch * kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + """ + cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) + cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, + step=kv_seqlen, + dtype=torch.int32, + device=k.device) + return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale, + causal) + + +if HAS_MEM_EFF_ATTN: + + from einops import rearrange + from xformers.ops.fmha import LowerTriangularMask + + class MemoryEfficientAttention(torch.nn.Module): + + def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0): + super().__init__() + attention_head_size = hidden_size // num_attention_heads + self.scale = 1 / attention_head_size**0.5 + self.dropout = attention_dropout + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor): + context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale) + context = rearrange(context, 'b s h d -> b s (h d)') + return context diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py index 38e95e2f8..4be336388 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -3,14 +3,11 @@ with some changes. """ import numbers -import torch -from torch.nn.parameter import Parameter -from torch.nn import init -from torch.cuda.amp import custom_fwd, custom_bwd -import importlib -global colossal_layer_norm_cuda -colossal_layer_norm_cuda = None +import torch +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn import init +from torch.nn.parameter import Parameter class FusedLayerNormAffineFunction(torch.autograd.Function): @@ -18,14 +15,18 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, input, weight, bias, normalized_shape, eps): + try: + from colossalai._C import layer_norm + except ImportError: + from colossalai.kernel.op_builder.layernorm import LayerNormBuilder + layer_norm = LayerNormBuilder().load() ctx.normalized_shape = normalized_shape ctx.eps = eps input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() - output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_, - ctx.eps) + output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @@ -33,11 +34,16 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): @staticmethod @custom_bwd def backward(ctx, grad_output): + try: + from colossalai._C import layer_norm + except ImportError: + from colossalai.kernel.op_builder.layernorm import LayerNormBuilder + layer_norm = LayerNormBuilder().load() input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None grad_input, grad_weight, grad_bias \ - = colossal_layer_norm_cuda.backward_affine( + = layer_norm.backward_affine( grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps) @@ -50,13 +56,6 @@ class MixedFusedLayerNorm(torch.nn.Module): def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): super(MixedFusedLayerNorm, self).__init__() - global colossal_layer_norm_cuda - if colossal_layer_norm_cuda is None: - try: - colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda") - except ImportError: - raise RuntimeError('MixedFusedLayerNorm requires cuda extensions') - if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) self.normalized_shape = torch.Size(normalized_shape) diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py index c93d1cf60..7df53731e 100644 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -1,5 +1,4 @@ import math -import importlib from dataclasses import dataclass import torch @@ -37,21 +36,21 @@ colossal_multihead_attention = None @dataclass class Config: - max_batch_tokens: int # max batch token numbers - max_seq_len: int # max sequence length - hidden_size: int # size of transformer hidden layers - nhead: int # number of heads in attention - attn_prob_dropout_ratio: float # attention score dropout ratio - hidden_dropout_ratio: float # dropout ration before residual - norm_first: bool # norm_first - fp16: bool # fp16 presion + max_batch_tokens: int # max batch token numbers + max_seq_len: int # max sequence length + hidden_size: int # size of transformer hidden layers + nhead: int # number of heads in attention + attn_prob_dropout_ratio: float # attention score dropout ratio + hidden_dropout_ratio: float # dropout ration before residual + norm_first: bool # norm_first + fp16: bool # fp16 presion class MultiHeadAttention1DFunc(Function): @staticmethod - def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, - out_proj_bias, norm_weight, norm_bias, config): + def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, + norm_bias, config): cuda_module = colossal_multihead_attention forward_func = (cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32) @@ -59,13 +58,12 @@ class MultiHeadAttention1DFunc(Function): input = input.to(torch.half) input_mask = input_mask.to(torch.half) - (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, - out_proj_weight, out_proj_bias, norm_weight, norm_bias, - config.training, config.norm_first) + (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, + out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first) if config.is_grad_enabled and config.training: - ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, - out_proj_weight, out_proj_bias, norm_weight, norm_bias) + ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, + out_proj_bias, norm_weight, norm_bias) ctx.config = config return output @@ -98,8 +96,8 @@ class MultiHeadAttention1DFunc(Function): ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) - return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, - grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None) + return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, + grad_norm_weight, grad_norm_bias, None) class MultiHeadAttention(nn.Module): @@ -121,19 +119,11 @@ class MultiHeadAttention(nn.Module): layer_id = 0 - def __init__(self, - hidden_size, - nhead, - batch_size, - max_seq_len, - dropout=0.0, - norm_first=False, - fp16=True, - pg=None): + def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): super(MultiHeadAttention, self).__init__() - self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, - dropout, norm_first, fp16) + self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, + fp16) check_config(self.config) self.pg = pg self.pg_size = 1 @@ -145,10 +135,9 @@ class MultiHeadAttention(nn.Module): # Load cuda modules if needed global colossal_multihead_attention if colossal_multihead_attention is None: - try: - colossal_multihead_attention = importlib.import_module("colossal_multihead_attention") - except ImportError: - raise RuntimeError('MultiHeadAttention requires cuda extensions') + from colossalai.kernel.op_builder import MultiHeadAttnBuilder + multihead_attention = MultiHeadAttnBuilder().load() + colossal_multihead_attention = multihead_attention # create the layer in cuda kernels. cuda_module = colossal_multihead_attention @@ -215,14 +204,13 @@ class MultiHeadAttention(nn.Module): with torch.no_grad(): self.in_proj_weight.copy_( - attn_qkvw_global.view(3, hs, hs)[ - :, int(hs * rank_in_pg / self.pg_size): - int(hs * (rank_in_pg + 1) / self.pg_size), - :]) + attn_qkvw_global.view(3, hs, hs)[:, + int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size), :]) self.in_proj_bias.copy_( - attn_qkvb_global.view(3, hs)[ - :, int(hs * rank_in_pg / self.pg_size): - int(hs * (rank_in_pg + 1) / self.pg_size)]) + attn_qkvb_global.view(3, hs)[:, + int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size)]) attn_ow_global = torch.empty(hs, hs) nn.init.xavier_uniform_(attn_ow_global, 1.0) @@ -230,9 +218,9 @@ class MultiHeadAttention(nn.Module): torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) attn_ow_global = attn_ow_global.cpu() with torch.no_grad(): - self.out_proj_weight.copy_(attn_ow_global[ - :, int(hs * rank_in_pg / self.pg_size): - int(hs * (rank_in_pg + 1) / self.pg_size)]) + self.out_proj_weight.copy_(attn_ow_global[:, + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)]) else: attn_qkvw = self.in_proj_weight.view(-1, hs) @@ -243,10 +231,7 @@ class MultiHeadAttention(nn.Module): nn.init.xavier_uniform_(self.out_proj_weight, 1.0) def state_dict(self, destination=None, prefix="", keep_vars=False): - destination = torch.nn.Module.state_dict(self, - destination=destination, - prefix=prefix, - keep_vars=keep_vars) + destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars) return destination def forward(self, hidden_states, encoder_padding_mask): @@ -257,8 +242,7 @@ class MultiHeadAttention(nn.Module): bs, sl, dim = hidden_states.size() if bs * sl > self.config.max_batch_tokens: - raise ValueError( - f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") + raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") if sl > self.config.max_seq_len: raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.") if len(encoder_padding_mask.size()) == 1: @@ -266,9 +250,8 @@ class MultiHeadAttention(nn.Module): else: assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) - output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, - self.in_proj_weight, self.in_proj_bias, - self.out_proj_weight, self.out_proj_bias, + output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight, + self.in_proj_bias, self.out_proj_weight, self.out_proj_bias, self.norm_weight, self.norm_bias, self.config) return output.to(self.precision) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index cb36da8a1..3f0260aae 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -1,9 +1,10 @@ """This code from NVIDIA Megatron with some changes. """ +import enum + import torch import torch.nn as nn -import enum class AttnMaskType(enum.Enum): @@ -22,26 +23,20 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @staticmethod def forward(ctx, inputs, scale): - try: - import colossal_scaled_upper_triang_masked_softmax - except ImportError: - raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') + from colossalai.kernel import scaled_upper_triang_masked_softmax scale_t = torch.tensor([scale]) - softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @staticmethod def backward(ctx, output_grads): - try: - import colossal_scaled_upper_triang_masked_softmax - except ImportError: - raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') + from colossalai.kernel import scaled_upper_triang_masked_softmax softmax_results, scale_t = ctx.saved_tensors - input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) return input_grads, None @@ -58,26 +53,28 @@ class ScaledMaskedSoftmax(torch.autograd.Function): @staticmethod def forward(ctx, inputs, mask, scale): try: - import colossal_scaled_masked_softmax + from colossalai._C import scaled_masked_softmax except ImportError: - raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') + from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder + scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() scale_t = torch.tensor([scale]) - softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @staticmethod def backward(ctx, output_grads): try: - import colossal_scaled_masked_softmax + from colossalai._C import scaled_masked_softmax except ImportError: - raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') + from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder + scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() softmax_results, scale_t = ctx.saved_tensors - input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None @@ -184,8 +181,8 @@ class FusedScaleMaskSoftmax(nn.Module): @staticmethod def get_batch_per_block(sq, sk, b, np): try: - import colossal_scaled_masked_softmax + import colossalai._C.scaled_masked_softmax except ImportError: raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') - return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) + return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index d95905897..aa41f5767 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -1,5 +1,11 @@ import torch +from colossalai.nn.layer.colossalai_layer import Embedding, Linear +from colossalai.utils import get_current_device + +from .bias_dropout_add import bias_dropout_add_fused_train +from .bias_gelu import bias_gelu_impl + JIT_OPTIONS_SET = False @@ -30,3 +36,44 @@ def set_jit_fusion_options(): torch._C._jit_override_can_fuse_on_gpu(True) JIT_OPTIONS_SET = True + + +def warmup_jit_fusion(batch_size: int, + hidden_size: int, + seq_length: int = 512, + vocab_size: int = 32768, + dtype: torch.dtype = torch.float32): + """ Compilie JIT functions before the main training steps """ + + embed = Embedding(vocab_size, hidden_size).to(get_current_device()) + linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) + linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device()) + + x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device()) + x = embed(x) + y, y_bias = linear_1(x) + z, z_bias = linear_2(y) + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for bias_grad, input_grad in zip([True, True], [False, True]): + for _ in range(10): + bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device()) + input_ = torch.rand_like(y, dtype=dtype, device=get_current_device()) + bias.requires_grad, input_.requires_grad = bias_grad, input_grad + bias_gelu_impl(input_, bias) + + # Warmup fused bias+dropout+add + dropout_rate = 0.1 + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): + for _ in range(10): + input_ = torch.rand_like(z, dtype=dtype, device=get_current_device()) + residual = torch.rand_like(x, dtype=dtype, device=get_current_device()) + bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device()) + input_.requires_grad = input_grad + bias.requires_grad = bias_grad + residual.requires_grad = residual_grad + bias_dropout_add_fused_train(input_, bias, residual, dropout_rate) + + torch.cuda.empty_cache() diff --git a/colossalai/kernel/op_builder b/colossalai/kernel/op_builder new file mode 120000 index 000000000..db4f9c335 --- /dev/null +++ b/colossalai/kernel/op_builder @@ -0,0 +1 @@ +../../op_builder \ No newline at end of file diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index acfc73c2d..af7b7de54 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -1,24 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import colossalai +import inspect import logging from pathlib import Path -from typing import Union, List -import inspect +from typing import List, Union +import colossalai from colossalai.context.parallel_mode import ParallelMode -try: - from rich.logging import RichHandler - _FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s' - logging.basicConfig(level=logging.INFO, - format=_FORMAT, - handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)]) -except ImportError: - _FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s' - logging.basicConfig(level=logging.INFO, format=_FORMAT) - class DistributedLogger: """This is a distributed event logger class essentially based on :class:`logging`. @@ -40,7 +30,7 @@ class DistributedLogger: Args: name (str): The name of the logger. - + Returns: DistributedLogger: A DistributedLogger object """ @@ -55,8 +45,23 @@ class DistributedLogger: raise Exception( 'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger') else: + handler = None + formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + try: + from rich.logging import RichHandler + handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True) + handler.setFormatter(formatter) + except ImportError: + handler = logging.StreamHandler() + handler.setFormatter(formatter) + self._name = name self._logger = logging.getLogger(name) + self._logger.setLevel(logging.INFO) + if handler is not None: + self._logger.addHandler(handler) + self._logger.propagate = False + DistributedLogger.__instances[name] = self @staticmethod @@ -119,7 +124,7 @@ class DistributedLogger: # add file handler file_handler = logging.FileHandler(path, mode) file_handler.setLevel(getattr(logging, level)) - formatter = logging.Formatter(_FORMAT) + formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') file_handler.setFormatter(formatter) self._logger.addHandler(file_handler) diff --git a/colossalai/nn/__init__.py b/colossalai/nn/__init__.py index 91fc0da55..910ad2031 100644 --- a/colossalai/nn/__init__.py +++ b/colossalai/nn/__init__.py @@ -1,6 +1,6 @@ +from ._ops import * from .layer import * from .loss import * from .lr_scheduler import * from .metric import * from .optimizer import * -from ._ops import * diff --git a/colossalai/nn/_ops/__init__.py b/colossalai/nn/_ops/__init__.py index 945505b74..4991ad9a2 100644 --- a/colossalai/nn/_ops/__init__.py +++ b/colossalai/nn/_ops/__init__.py @@ -1,8 +1,9 @@ -from .linear import colo_linear -from .element_wise import * -from .layernorm import colo_layernorm -from .loss import colo_cross_entropy -from .embedding import colo_embedding from .addmm import colo_addmm +from .batch_norm import colo_batch_norm +from .element_wise import * +from .embedding import colo_embedding from .embedding_bag import colo_embedding_bag -from .view import colo_view \ No newline at end of file +from .layernorm import colo_layernorm +from .linear import colo_linear +from .loss import colo_cross_entropy +from .view import colo_view diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py index ce7e8bef6..fe2eb0c99 100644 --- a/colossalai/nn/_ops/addmm.py +++ b/colossalai/nn/_ops/addmm.py @@ -55,7 +55,7 @@ def colo_addmm(input_tensor: GeneralTensor, mat2: ColoTensor, beta: Number = 1, alpha: Number = 1, - *args) -> ColoTensor: + **kargs) -> ColoTensor: """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. This method computes a linear. """ @@ -70,7 +70,7 @@ def colo_addmm(input_tensor: GeneralTensor, assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' ret_tensor = ColoTensor.from_torch_tensor( - tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha), + tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs), spec=ColoTensorSpec(mat2.get_process_group())) elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied if mat2.is_shard_1drow() and input_tensor.is_replicate(): diff --git a/colossalai/nn/_ops/batch_norm.py b/colossalai/nn/_ops/batch_norm.py new file mode 100644 index 000000000..54ecc88f4 --- /dev/null +++ b/colossalai/nn/_ops/batch_norm.py @@ -0,0 +1,33 @@ +from typing import Optional + +import torch.nn.functional as F + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec +from colossalai.tensor.op_wrapper import colo_op_impl + +from ._utils import GeneralTensor, convert_to_colo_tensor + + +@colo_op_impl(F.batch_norm) +def colo_batch_norm( + input: GeneralTensor, + running_mean: Optional[GeneralTensor], + running_var: Optional[GeneralTensor], + weight: Optional[GeneralTensor] = None, + bias: Optional[GeneralTensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +): + assert isinstance(weight, ColoTensor) + running_mean = running_mean.detach() + running_var = running_var.detach() + + input = convert_to_colo_tensor(input, weight.get_process_group()) + bias = convert_to_colo_tensor(bias, weight.get_process_group()) + input = input.redistribute(ReplicaSpec()) + bias = bias.redistribute(ReplicaSpec()) + + output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps) + output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group())) + return output diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py index 462670e72..2de51e24a 100644 --- a/colossalai/nn/_ops/element_wise.py +++ b/colossalai/nn/_ops/element_wise.py @@ -1,9 +1,11 @@ import torch import torch.nn.functional as F from torch import Tensor -from colossalai.tensor.op_wrapper import colo_op_impl + from colossalai.tensor import ColoTensor, ColoTensorSpec -from ._utils import GeneralTensor +from colossalai.tensor.op_wrapper import colo_op_impl + +from ._utils import GeneralTensor, convert_to_colo_tensor def register_elementwise_op(op): @@ -15,17 +17,32 @@ def register_elementwise_op(op): as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``. This method computes on either a normal tensor or a sharded tensor. """ + if 'inplace' in kwargs: + # TODO(jiaruifang) inplace will cause bugs + input_tensor = input_tensor.clone() + return op(input_tensor, *args, **kwargs) + else: + output = op(input_tensor, *args, **kwargs) + # return output + if isinstance(input_tensor, ColoTensor): + if isinstance(output, str): + return output + if not isinstance(output, torch.Tensor): + raise NotImplementedError + return ColoTensor.from_torch_tensor(output, + spec=ColoTensorSpec(input_tensor.get_process_group(), + dist_attr=input_tensor.dist_spec)) - output = op(input_tensor, *args, **kwargs) - if isinstance(input_tensor, ColoTensor): - if isinstance(output, str): - return output - if not isinstance(output, torch.Tensor): - raise NotImplementedError - return ColoTensor.from_torch_tensor(output, - spec=ColoTensorSpec(input_tensor.get_process_group(), - dist_attr=input_tensor.dist_spec)) +# @colo_op_impl(torch.relu_) +# def elementwise_op(input_tensor): +# torch.relu_(input_tensor.data) +# return input_tensor + +# @colo_op_impl(Tensor.add_) +# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs): +# input_tensor = input_tensor.data.add_(*args, **kwargs) +# return input_tensor # Tensor op register_elementwise_op(Tensor.abs) diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py index 8835574de..2f2088c61 100644 --- a/colossalai/nn/_ops/linear.py +++ b/colossalai/nn/_ops/linear.py @@ -1,11 +1,13 @@ -import torch.nn.functional as F -from typing import Optional -from ._utils import GeneralTensor, convert_to_colo_tensor -from colossalai.tensor.op_wrapper import colo_op_impl -from ._utils import reduce_input, reduce_grad -from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec -from colossalai.tensor.sharding_spec import ShardingSpec from copy import deepcopy +from typing import Optional + +import torch.nn.functional as F + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.tensor.sharding_spec import ShardingSpec + +from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': @@ -155,17 +157,15 @@ def _new_colo_linear_imp(input_tensor: GeneralTensor, def _has_sharding_spec(tensor): """ - A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is + A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is set as the attribute `sharding_spec` on a tensor. """ return hasattr(tensor, 'sharding_spec') @colo_op_impl(F.linear) -def colo_linear(input_tensor: GeneralTensor, - weight: GeneralTensor, - bias: Optional[GeneralTensor] = None) -> 'ColoTensor': +def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor': if _has_sharding_spec(weight): - return _new_colo_linear_imp(input_tensor, weight, bias) + return _new_colo_linear_imp(input, weight, bias) else: - return colo_linear_imp(input_tensor, weight, bias) + return colo_linear_imp(input, weight, bias) diff --git a/colossalai/nn/graph/__init__.py b/colossalai/nn/graph/__init__.py deleted file mode 100644 index 0cfecf8b4..000000000 --- a/colossalai/nn/graph/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .utils import register_colo_graph -from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode - -__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode'] \ No newline at end of file diff --git a/colossalai/nn/graph/graph_node.py b/colossalai/nn/graph/graph_node.py deleted file mode 100644 index 32653ad98..000000000 --- a/colossalai/nn/graph/graph_node.py +++ /dev/null @@ -1,96 +0,0 @@ -from colossalai.tensor import ColoTensor -from colossalai.context.singleton_meta import SingletonMeta - - -class GraphGlobalEnv(metaclass=SingletonMeta): - - def __init__(self) -> None: - self.graph_building = False - self.graph_node_list = [] - self.node_id = -1 - - def get_node_id(self): - self.node_id += 1 - return self.node_id - - def add_graph_node(self, node): - self.graph_node_list.append(node) - - -class GraphContext(): - """ - - Building the computing graph under the context - - >>> with GraphContext(): - >>> output = model(colo_input_tensor) - """ - graph_nodes = [] - - def __enter__(self): - GraphGlobalEnv().graph_building = True - GraphGlobalEnv().graph_node_list = [] - - def __exit__(self, *exc_info): - GraphGlobalEnv().graph_building = False - GraphGlobalEnv().node_id = -1 - self.graph_nodes = GraphGlobalEnv().graph_node_list - - -class GraphNode(object): - - def __init__(self) -> None: - self.prev_nodes = [] - self.post_nodes = [] - self.id = GraphGlobalEnv().get_node_id() - - def add_prev_node(self, node): - if GraphGlobalEnv().graph_building: - self.prev_nodes.append(node) - - def add_post_node(self, node): - if GraphGlobalEnv().graph_building: - self.post_nodes.append(node) - - def post_node_empty(self) -> bool: - return len(self.post_nodes) == 0 - - -class GraphOpNode(GraphNode): - - def __init__(self, op_type, param_list) -> None: - super().__init__() - self._op_type = op_type - self._param_list = param_list - GraphGlobalEnv().add_graph_node(self) - - def add_prev_tensor(self, colo_tensor: ColoTensor): - r""" - Link the current graph op node to previous graph op. - Op1 <- Activation (colo_tensor) Op2 - Op1 <- Op2 - """ - if GraphGlobalEnv().graph_building: - assert isinstance(colo_tensor, ColoTensor) - if colo_tensor._graph_node is None: - colo_tensor._graph_node = GraphNode() - prev_ops = colo_tensor._graph_node.prev_nodes - for op_node in prev_ops: - self.add_prev_node(op_node) - op_node.add_post_node(self) - - def add_post_tensor(self, colo_tensor: ColoTensor): - """ - Op <- Activation (colo_tensor) - """ - if GraphGlobalEnv().graph_building: - assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}' - if colo_tensor._graph_node is None: - colo_tensor._graph_node = GraphNode() - - colo_tensor._graph_node.add_prev_node(self) - - def print(self): - print( - f'GraphOpNode {self._op_type} {self.id}, post nodes {[node.id for node in self.post_nodes]}, prev node number {[node.id for node in self.prev_nodes]}' - ) diff --git a/colossalai/nn/graph/utils.py b/colossalai/nn/graph/utils.py deleted file mode 100644 index 1070319ca..000000000 --- a/colossalai/nn/graph/utils.py +++ /dev/null @@ -1,51 +0,0 @@ -import functools -import torch -from colossalai.tensor import ColoTensor -from typing import Callable, List -from colossalai.nn._ops._utils import convert_to_colo_tensor - - -def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable: - """register_colo_graph - Register a Op (Layer) to ColoGraph. - Recoders the input args in types of ColoTensor to the Graph. - - Args: - func (Callable): a function implements the Op. - - Returns: - Callable: wrapper function. - """ - - def register_colo_graph_decorator(func): - from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv - - @functools.wraps(func) - def wrapper(*args, **kwargs): - param_list = [] - input_list = [] - # TODO(jiaruifang) find the pg - for idx, arg in enumerate(args): - if isinstance(arg, torch.Tensor) and idx in input_pos: - input_list.append(convert_to_colo_tensor(arg)) - if isinstance(arg, torch.Tensor) and idx in param_pos: - param_list.append(convert_to_colo_tensor(arg)) - # building the computing graph, inputs -> op - if GraphGlobalEnv().graph_building: - cur_op_node = GraphOpNode('linear', param_list) - # TODO supports a list of ColoTensor as args - if len(input_list) > 0: - cur_op_node.add_prev_tensor(input_list[0]) - - outputs = func(*args, **kwargs) - - # building the computing graph, op -> output - if GraphGlobalEnv().graph_building: - # TODO supports a list of ColoTensor as args - if isinstance(outputs[0], ColoTensor): - cur_op_node.add_post_tensor(outputs[0]) - return outputs - - return wrapper - - return register_colo_graph_decorator diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/nn/layer/colossalai_layer/linear.py index f3f35838b..3e0c6e285 100644 --- a/colossalai/nn/layer/colossalai_layer/linear.py +++ b/colossalai/nn/layer/colossalai_layer/linear.py @@ -1,147 +1,141 @@ -import math -import inspect -from typing import Callable - -from colossalai.utils import get_current_device -from torch import dtype, nn - -from ... import init as init -from ..parallel_1d import * -from ..parallel_2d import * -from ..parallel_2p5d import * -from ..parallel_3d import * -from ..utils import get_tensor_parallel_mode -from ..vanilla import * -from ._utils import ColossalaiModule - -_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} - -_parallel_classifier = { - None: VanillaClassifier, - '1d': Classifier1D, - '2d': Classifier2D, - '2.5d': Classifier2p5D, - '3d': Classifier3D -} - -_vocab_parallel_classifier = { - '1d': VocabParallelClassifier1D, - '2d': VocabParallelClassifier2D, - '2.5d': VocabParallelClassifier2p5D, - '3d': VocabParallelClassifier3D -} - - -class Linear(ColossalaiModule): - """Linear layer of colossalai. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - Note: ``kwargs`` would contain different parameters when you use different parallelisms. - - The ``kwargs`` should contain parameters below: - :: - - Linear1D: - gather_output: bool (optional, default to be false) - skip_bias_add: bool (optional, default to be false) - Linear2D: - skip_bias_add: bool (optional, default to be false) - Linear2p5D: - skip_bias_add: bool (optional, default to be false) - Linear3D: - None - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - **kwargs) -> None: - tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel is None: - layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device()) - weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features) - if layer.bias is not None: - bias_initializer(layer.bias, fan_in=in_features) - else: - linear_cls = _parallel_linear[tensor_parallel] - gather_output = kwargs.pop('gather_output', None) - if 'gather_output' in inspect.signature( - linear_cls.__init__).parameters.keys(): # gather_out arg is available - kwargs['gather_output'] = gather_output - layer = linear_cls( - in_features, - out_features, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - **kwargs, - ) - super().__init__(layer) - - -class Classifier(ColossalaiModule): - """Classifier layer of colossalai. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - vocab_parallel_limit: int = 2048) -> None: - tensor_parallel = get_tensor_parallel_mode() - if num_classes <= vocab_parallel_limit or tensor_parallel is None: - layer = _parallel_classifier[tensor_parallel]( - in_features, - num_classes, - weight=weight, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - ) - else: - layer = _vocab_parallel_classifier[tensor_parallel]( - in_features, - num_classes, - weight=weight, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - ) - super().__init__(layer) +import inspect +import math +from typing import Callable + +from torch import dtype, nn + +from colossalai.utils import get_current_device + +from ... import init as init +from ..parallel_1d import * +from ..parallel_2d import * +from ..parallel_2p5d import * +from ..parallel_3d import * +from ..utils import get_tensor_parallel_mode +from ..vanilla import * +from ._utils import ColossalaiModule + +_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} + +_parallel_classifier = { + None: VanillaClassifier, + '1d': Classifier1D, + '2d': Classifier2D, + '2.5d': Classifier2p5D, + '3d': Classifier3D +} + +_vocab_parallel_classifier = { + '1d': VocabParallelClassifier1D, + '2d': VocabParallelClassifier2D, + '2.5d': VocabParallelClassifier2p5D, + '3d': VocabParallelClassifier3D +} + + +class Linear(ColossalaiModule): + """Linear layer of colossalai. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + Note: ``kwargs`` would contain different parameters when you use different parallelisms. + + The ``kwargs`` should contain parameters below: + :: + + Linear1D: + gather_output: bool (optional, default to be false) + skip_bias_add: bool (optional, default to be false) + Linear2D: + skip_bias_add: bool (optional, default to be false) + Linear2p5D: + skip_bias_add: bool (optional, default to be false) + Linear3D: + None + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs) -> None: + tensor_parallel = get_tensor_parallel_mode() + linear_cls = _parallel_linear[tensor_parallel] + gather_output = kwargs.pop('gather_output', None) + if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available + kwargs['gather_output'] = gather_output + layer = linear_cls( + in_features, + out_features, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + **kwargs, + ) + super().__init__(layer) + + +class Classifier(ColossalaiModule): + """Classifier layer of colossalai. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + vocab_parallel_limit: int = 2048) -> None: + tensor_parallel = get_tensor_parallel_mode() + if num_classes <= vocab_parallel_limit or tensor_parallel is None: + layer = _parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + else: + layer = _vocab_parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + super().__init__(layer) diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py index dbf264297..37f31c167 100644 --- a/colossalai/nn/layer/moe/_operation.py +++ b/colossalai/nn/layer/moe/_operation.py @@ -1,153 +1,175 @@ -import torch -import torch.distributed as dist -from torch import Tensor -from typing import Any, Tuple, Optional -from torch.distributed import ProcessGroup - -COL_MOE_KERNEL_FLAG = False -try: - import colossal_moe_cuda - - COL_MOE_KERNEL_FLAG = True -except ImportError: - print("If you want to activate cuda mode for MoE, please install with cuda_ext!") - - -class AllGather(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - if ctx is not None: - ctx.comm_grp = group - - comm_size = dist.get_world_size(group) - if comm_size == 1: - return inputs.unsqueeze(0) - - buffer_shape = (comm_size,) + inputs.shape - outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) - buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) - dist.all_gather(buffer_list, inputs, group=group) - return outputs - - @staticmethod - def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: - return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None - - -class ReduceScatter(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - if ctx is not None: - ctx.comm_grp = group - - comm_size = dist.get_world_size(group) - if comm_size == 1: - return inputs.squeeze(0) - - if not inputs.is_contiguous(): - inputs = inputs.contiguous() - - output_shape = inputs.shape[1:] - outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) - buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) - dist.reduce_scatter(outputs, buffer_list, group=group) - return outputs - - @staticmethod - def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllGather.forward(None, grad_outputs, ctx.comm_grp), None - - -class AllToAll(torch.autograd.Function): - """Dispatches input tensor [e, c, h] to all experts by all_to_all_single - operation in torch.distributed. - """ - - @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - if ctx is not None: - ctx.comm_grp = group - if not inputs.is_contiguous(): - inputs = inputs.contiguous() - if dist.get_world_size(group) == 1: - return inputs - output = torch.empty_like(inputs) - dist.all_to_all_single(output, inputs, group=group) - return output - - @staticmethod - def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None - - -class MoeDispatch(torch.autograd.Function): - - @staticmethod - def forward(ctx, tokens, mask, dest_idx, ec): - s = tokens.size(0) - h = tokens.size(1) - - expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx) - - ctx.save_for_backward(mask, dest_idx) - ctx.s = s - ctx.h = h - ctx.ec = ec - - return expert_input - - @staticmethod - def backward(ctx, output_grad): - mask, dest_idx = ctx.saved_tensors - d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) - return d_tokens, None, None, None - - -class MoeCombine(torch.autograd.Function): - - @staticmethod - def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): - assert logits.dtype == torch.float32 - - s = logits.size(0) - e = logits.size(1) - c = ec // e - h = expert_tokens.size(-1) - - fp16_flag = (expert_tokens.dtype == torch.float16) - cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens - ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) - output = ctokens.to(torch.float16) if fp16_flag else ctokens - - ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) - ctx.s = s - ctx.e = e - ctx.c = c - ctx.h = h - ctx.fp16_flag = fp16_flag - - return output - - @staticmethod - def backward(ctx, tokens_grad): - expert_tokens, logits, mask, dest_idx = ctx.saved_tensors - - cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ - else tokens_grad - cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens - d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, - mask, dest_idx) - d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert - - return d_expert, d_logits, None, None, None - - -def moe_cumsum(inputs: Tensor): - dim0 = inputs.size(0) - flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) - if flag and COL_MOE_KERNEL_FLAG: - return colossal_moe_cuda.cumsum_sub_one(inputs) - else: - return torch.cumsum(inputs, dim=0) - 1 +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +COL_MOE_KERNEL_FLAG = False + +try: + from colossalai._C import moe +except: + moe = None + + +def build_moe_if_not_prebuilt(): + # load moe kernel during runtime if not pre-built + global moe + if moe is None: + from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() + + +class AllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + + global moe + + if moe is None: + from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() + + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.unsqueeze(0) + + buffer_shape = (comm_size,) + inputs.shape + outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) + dist.all_gather(buffer_list, inputs, group=group) + return outputs + + @staticmethod + def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: + return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None + + +class ReduceScatter(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.squeeze(0) + + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + + output_shape = inputs.shape[1:] + outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) + dist.reduce_scatter(outputs, buffer_list, group=group) + return outputs + + @staticmethod + def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: + return AllGather.forward(None, grad_outputs, ctx.comm_grp), None + + +class AllToAll(torch.autograd.Function): + """Dispatches input tensor [e, c, h] to all experts by all_to_all_single + operation in torch.distributed. + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + if dist.get_world_size(group) == 1: + return inputs + output = torch.empty_like(inputs) + dist.all_to_all_single(output, inputs, group=group) + return output + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None + + +class MoeDispatch(torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, mask, dest_idx, ec): + s = tokens.size(0) + h = tokens.size(1) + + # load moe kernel during runtime if not pre-built + build_moe_if_not_prebuilt() + + expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx) + + ctx.save_for_backward(mask, dest_idx) + ctx.s = s + ctx.h = h + ctx.ec = ec + + return expert_input + + @staticmethod + def backward(ctx, output_grad): + mask, dest_idx = ctx.saved_tensors + d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) + return d_tokens, None, None, None + + +class MoeCombine(torch.autograd.Function): + + @staticmethod + def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): + assert logits.dtype == torch.float32 + + s = logits.size(0) + e = logits.size(1) + c = ec // e + h = expert_tokens.size(-1) + + # load moe kernel during runtime if not pre-built + build_moe_if_not_prebuilt() + + fp16_flag = (expert_tokens.dtype == torch.float16) + cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens + ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) + output = ctokens.to(torch.float16) if fp16_flag else ctokens + + ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) + ctx.s = s + ctx.e = e + ctx.c = c + ctx.h = h + ctx.fp16_flag = fp16_flag + + return output + + @staticmethod + def backward(ctx, tokens_grad): + expert_tokens, logits, mask, dest_idx = ctx.saved_tensors + + cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ + else tokens_grad + cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens + d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx) + d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert + + return d_expert, d_logits, None, None, None + + +def moe_cumsum(inputs: Tensor): + dim0 = inputs.size(0) + flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) + if flag and COL_MOE_KERNEL_FLAG: + # load moe kernel during runtime if not pre-built + build_moe_if_not_prebuilt() + return moe.cumsum_sub_one(inputs) + else: + return torch.cumsum(inputs, dim=0) - 1 diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index 7944598b7..394334558 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -1,4 +1,6 @@ import torch +import torch.distributed as dist +from colossalai.core import global_context as gpc try: import fused_mix_prec_layer_norm_cuda @@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function): weight_, bias_, ctx.eps) return grad_input, grad_weight, grad_bias, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.parallel_mode = parallel_mode + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight) + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) + total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index fd26f67e8..e96abd87e 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -7,6 +7,9 @@ from typing import Callable, Tuple import torch import torch.nn.functional as F +from torch import Tensor +from torch.nn.parameter import Parameter + from colossalai.communication import broadcast from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc @@ -14,18 +17,33 @@ from colossalai.global_variables import tensor_parallel_env as env from colossalai.kernel import LayerNorm from colossalai.nn import init as init from colossalai.registry import LAYERS -from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict) +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) from colossalai.utils.cuda import get_current_device -from torch import Tensor -from torch.nn.parameter import Parameter -from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule from ..utils import divide, set_tensor_parallel_attribute_by_partition -from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input, - split_forward_gather_backward) +from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding +from ._operation import linear_with_async_comm +from ._utils import ( + gather_forward_split_backward, + get_parallel_input, + reduce_grad, + reduce_input, + set_parallel_input, + split_forward_gather_backward, +) + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass @LAYERS.register_module @@ -59,12 +77,11 @@ class Linear1D(ColossalaiModule): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): parallel_input = get_parallel_input() - if not parallel_input: + if not parallel_input and not gather_output: layer = Linear1D_Col(in_features, out_features, bias=bias, dtype=dtype, - gather_output=gather_output, skip_bias_add=skip_bias_add, weight_initializer=weight_initializer, bias_initializer=bias_initializer) @@ -96,8 +113,21 @@ class LayerNorm1D(ColossalaiModule): dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. """ + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype) + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps).to(dtype) super().__init__(norm) def _load_from_state_dict(self, state_dict, prefix, *args): @@ -519,11 +549,12 @@ class Linear1D_Col(ParallelLayer): 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) # Set up backprop all-reduce. - input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - output_parallel = F.linear(input_parallel, self.weight, bias) + # output_parallel = F.linear(input_parallel, self.weight, bias) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) if self.gather_output: # All-gather across the partitions. output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) @@ -565,9 +596,12 @@ class Linear1D_Row(ParallelLayer): parallel_input: bool = True, skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): super().__init__() + self.stream_chunk_num = stream_chunk_num + # Keep input parameters self.in_features = in_features self.out_features = out_features @@ -585,6 +619,9 @@ class Linear1D_Row(ParallelLayer): factory_kwargs = {'device': get_current_device(), 'dtype': dtype} self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() if bias: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) else: @@ -594,6 +631,9 @@ class Linear1D_Row(ParallelLayer): self._set_tensor_parallel_attributes() set_parallel_input(False) + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.out_features weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) @@ -664,9 +704,26 @@ class Linear1D_Row(ParallelLayer): input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - output_parallel = F.linear(input_, self.weight) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=gpc.get_group(ParallelMode.PARALLEL_1D), + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) if not self.skip_bias_add: if self.bias is not None: output = output + self.bias diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py old mode 100644 new mode 100755 index eb045f2b4..5dc9a2428 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -4,121 +4,112 @@ from typing import Optional, Tuple import torch -from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter) -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd -from ._utils import get_parallel_mode_from_env + +from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc + +from ._utils import get_parallel_mode_from_env, push_async_grad class _Linear3D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, - input_: Tensor, - weight: Tensor, - bias: Optional[Tensor], - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, - input_dim: int = 0, - weight_dim: int = -1, - output_dim: int = 0) -> Tensor: - ctx.use_bias = bias is not None - - input_ = all_gather(input_, input_dim, input_parallel_mode) - weight = all_gather(weight, weight_dim, weight_parallel_mode) - ctx.save_for_backward(input_, weight) - - output = torch.matmul(input_, weight) - output = reduce_scatter(output, output_dim, output_parallel_mode) - - if bias is not None: - output += bias - + def forward( + ctx, + input_: Tensor, + weight: Tensor, + weight_id: int, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id ctx.input_parallel_mode = input_parallel_mode ctx.weight_parallel_mode = weight_parallel_mode ctx.output_parallel_mode = output_parallel_mode - ctx.input_dim = input_dim - ctx.weight_dim = weight_dim - ctx.output_dim = output_dim + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight, 0, weight_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + return output @staticmethod @custom_bwd def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_, weight = ctx.saved_tensors - with torch.no_grad(): - output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode) + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) - async_ops = list() + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) - input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) - input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True) - async_ops.append(op) + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - weight_grad = torch.matmul( - input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) - weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True) - async_ops.append(op) + input_op.wait() - if ctx.use_bias: - bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) - bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) - async_ops.append(op) - else: - bias_grad = None - - for op in async_ops: - if op is not None: - op.wait() - - return input_grad, weight_grad, bias_grad, None, None, None, None, None, None + return input_grad, weight_grad, None, None, None, None -def linear_3d(input_: Tensor, - weight: Tensor, - bias: Optional[Tensor], - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, - input_dim: int = 0, - weight_dim: int = -1, - output_dim: int = 0) -> Tensor: +def linear_3d( + input_: Tensor, + weight: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: r"""Linear layer for 3D parallelism. Args: input_ (:class:`torch.tensor`): input matrix. weight (:class:`torch.tensor`): matrix of weight. - bias (:class:`torch.tensor`): matrix of bias. input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. - input_dim (int, optional): dimension of input, defaults to 0. - weight_dim (int, optional): dimension of weight, defaults to -1. - output_dim (int, optional): dimension of output, defaults to 0. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode, - input_dim, weight_dim, output_dim) + return _Linear3D.apply( + input_, + weight, + id(weight), + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) class _Classifier3D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: ctx.use_bias = bias is not None + ctx.weight_id = weight_id - ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode) - src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)] + src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)] weight = broadcast(weight, src_rank, input_parallel_mode) ctx.save_for_backward(input_, weight) @@ -126,6 +117,7 @@ class _Classifier3D(torch.autograd.Function): output = all_reduce(output, output_parallel_mode) if bias is not None: + ctx.bias_id = bias_id output += bias ctx.src_rank = src_rank @@ -138,37 +130,36 @@ class _Classifier3D(torch.autograd.Function): @custom_bwd def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_, weight = ctx.saved_tensors - with torch.no_grad(): - async_ops = list() + weight_grad = torch.matmul( + output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) + weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) + if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): + weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + else: + weight_grad = None - weight_grad = torch.matmul( - output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) - weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) - if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): - weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) - async_ops.append(op) - else: - weight_grad = None + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None - if ctx.use_bias: - bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) - bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) - bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) - async_ops.append(op) - else: - bias_grad = None + input_grad = torch.matmul(output_grad, weight) - input_grad = torch.matmul(output_grad, weight) - - for op in async_ops: - if op is not None: - op.wait() - - return input_grad, weight_grad, bias_grad, None, None, None, None, None, None + return input_grad, weight_grad, bias_grad, None, None, None, None, None -def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: +def classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: r"""3D parallel classifier. Args: @@ -183,33 +174,166 @@ def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_ The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode) + return _Classifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +class _VocabParallelClassifier3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.use_bias = bias is not None + ctx.weight_id = weight_id + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + + if bias is not None: + ctx.bias_id = bias_id + output += bias + + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None + + input_op.wait() + + return input_grad, weight_grad, bias_grad, None, None, None, None, None + + +def vocab_parallel_classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D vocab parallel classifier. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _VocabParallelClassifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +@torch.jit.script +def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float): + mu = x - mean + var = sqr_mean - mean**2 + sigma = torch.sqrt(var + eps) + z = mu / sigma + output = weight * z + bias + + return output, mu, sigma + + +@torch.jit.script +def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): + # dbias, dweight = grad, grad * mu / sigma + dz = grad * weight + dmu = dz / sigma + dvar = dz * mu * (-0.5) * sigma**(-3) + dmean = -dmu + dvar = torch.sum(dvar, -1, keepdim=True) + dmean = torch.sum(dmean, -1, keepdim=True) + + return dmu, dmean, dvar class _Layernorm3D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float, - input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode) -> Tensor: - mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape - mu = input_ - mean - var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape - sigma = torch.sqrt(var + eps) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Tensor, + weight_id: int, + bias_id: int, + normalized_shape: int, + eps: float, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id + ctx.bias_id = bias_id + + sum_ = torch.sum(input_, dim=-1, keepdim=True) + sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True) + mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape + + output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps) ctx.save_for_backward(mu, sigma, weight) - z = mu / sigma - output = weight * z - if bias is not None: - output = output + bias - - ctx.use_bias = bias is not None ctx.normalized_shape = normalized_shape - ctx.input_parallel_mode = input_parallel_mode - ctx.weight_parallel_mode = weight_parallel_mode ctx.output_parallel_mode = output_parallel_mode + ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode return output @@ -217,34 +341,31 @@ class _Layernorm3D(torch.autograd.Function): @custom_bwd def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: mu, sigma, weight = ctx.saved_tensors - with torch.no_grad(): - weight_grad = output_grad * mu / sigma - if ctx.use_bias: - bias_grad = output_grad - weight_grad = torch.stack([bias_grad, weight_grad]).contiguous() - else: - bias_grad = None - weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1])) - weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode) - weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode) - if ctx.use_bias: - bias_grad, weight_grad = weight_grad[0], weight_grad[1] - dz = output_grad * weight - dvar = dz * mu * (-0.5) * sigma**(-3) - dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode) - dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape - dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode) + bias_grad, weight_grad = output_grad, output_grad * mu / sigma + bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1])) + weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - input_grad = dz / sigma + dvar * 2 * mu / \ - ctx.normalized_shape + dmean / ctx.normalized_shape + dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight) + dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode) + input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape - return input_grad, weight_grad, bias_grad, None, None, None, None, None + return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None -def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float, - input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode) -> Tensor: +def layernorm_3d( + input_: Tensor, + weight: Tensor, + bias: Tensor, + normalized_shape: int, + eps: float, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, +) -> Tensor: r"""3D parallel Layernorm. Args: @@ -257,16 +378,24 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normali If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps (float): a value added to the denominator for numerical stability - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode, - output_parallel_mode) + return _Layernorm3D.apply( + input_, + weight, + bias, + id(weight), + id(bias), + normalized_shape, + eps, + output_parallel_mode, + input_x_weight_parallel_mode, + ) def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: @@ -315,17 +444,12 @@ def split_batch_3d(input_: Tensor, The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_. """ - dim_size = input_.size(dim) + if input_.size(dim) <= 1: + return input_ weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_world_size = gpc.get_world_size(weight_parallel_mode) input_world_size = gpc.get_world_size(input_parallel_mode) - - assert dim_size % (input_world_size*weight_world_size) == 0, \ - f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).' - - if input_.size(dim) <= 1: - return input_ output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() return output @@ -464,47 +588,3 @@ def reduce_by_batch_3d(tensor: Tensor, in `parallel_mode `_ """ return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean) - - -class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function): - r"""broadcast weight from diagonal. - - Args: - input_ (:class:`torch.tensor`): input matrix. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode) -> Tensor: - ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode) - src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)] - output = broadcast(input_, src_rank, input_parallel_mode) - ctx.src_rank = src_rank - ctx.input_parallel_mode = input_parallel_mode - ctx.weight_parallel_mode = weight_parallel_mode - ctx.output_parallel_mode = output_parallel_mode - return output - - @staticmethod - @custom_bwd - def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: - input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode) - if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): - input_grad = all_reduce(input_grad, ctx.weight_parallel_mode) - else: - input_grad = None - return input_grad, None, None, None - - -def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: - return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode, - output_parallel_mode) diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/nn/layer/parallel_3d/_utils.py index 0622164cd..364191a79 100644 --- a/colossalai/nn/layer/parallel_3d/_utils.py +++ b/colossalai/nn/layer/parallel_3d/_utils.py @@ -1,8 +1,12 @@ -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D -from colossalai.context.parallel_mode import ParallelMode +from collections import OrderedDict +from functools import partial + +import torch +from torch import Tensor + +from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env -from torch import Tensor def get_depth_from_env() -> int: @@ -17,30 +21,17 @@ def get_depth_from_env() -> int: def get_parallel_mode_from_env(group): - assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \ + assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \ f'{group} is not valid for 3D tensor parallelism.' return getattr(env, group) -def get_last_group(a, b): - mapping = { - ParallelMode.PARALLEL_3D_INPUT: 'A', - ParallelMode.PARALLEL_3D_WEIGHT: 'B', - ParallelMode.PARALLEL_3D_OUTPUT: 'C', - } - - res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b])) - - if res == 'A': - return ParallelMode.PARALLEL_3D_INPUT - elif res == 'B': - return ParallelMode.PARALLEL_3D_WEIGHT - elif res == 'C': - return ParallelMode.PARALLEL_3D_OUTPUT - - def swap_in_out_group(): env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d + env.input_x_weight_group_3d, env.output_x_weight_group_3d = ( + env.output_x_weight_group_3d, + env.input_x_weight_group_3d, + ) def dbg_check_shape(tensor: Tensor, shape: tuple): @@ -49,3 +40,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple): print(tensor.shape) assert tensor.shape == shape, \ '{} does not match {}'.format(tensor.shape, shape) + + +class AsyncGradientBucket(object): + + def __init__(self): + self.bucket = OrderedDict() + + def __len__(self): + return len(self.bucket) + + def push(self, async_op, grad_tensor, param_id): + self.bucket[param_id] = tuple((async_op, grad_tensor)) + return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device) + + def pop(self, param_id): + grad = None + if param_id in self.bucket: + op, grad = self.bucket.pop(param_id) + if op is not None: + op.wait() + return grad + + def synchronize(self, params): + for p in params: + i = id(p) + if i in self.bucket: + op, grad = self.bucket.pop(i) + if op is not None: + op.wait() + p.grad.add_(grad) + + +_async_grad_bucket = AsyncGradientBucket() + + +def push_async_grad(op, grad, param_id): + return _async_grad_bucket.push(op, grad, param_id) + + +def pop_async_grad(param_id): + return _async_grad_bucket.pop(param_id) + + +def _async_grad_hook(grad, param_id): + grad.add_(pop_async_grad(param_id)) + return grad + + +def register_async_grad_hook(param): + param.register_hook(partial(_async_grad_hook, param_id=id(param))) + + +def synchronize(params=list()): + _async_grad_bucket.synchronize(params) + torch.cuda.default_stream().synchronize() + if len(_async_grad_bucket) > 0: + raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.") diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index 037a09763..99b0c3f8b 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -5,24 +5,36 @@ from typing import Callable import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + from colossalai.communication import all_reduce, broadcast -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn import init as init from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.registry import LAYERS -from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict) +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) from colossalai.utils.cuda import get_current_device -from torch import Tensor -from torch.nn import Parameter from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple -from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d, - linear_3d, reduce_scatter_tensor_3d, split_tensor_3d) -from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group +from ._operation import ( + all_gather_tensor_3d, + classifier_3d, + layernorm_3d, + linear_3d, + reduce_scatter_tensor_3d, + split_batch_3d, + split_tensor_3d, + vocab_parallel_classifier_3d, +) +from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group @LAYERS.register_module @@ -45,7 +57,8 @@ class LayerNorm3D(ParallelLayer): super().__init__() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) self.depth = get_depth_from_env() self.normalized_shape = normalized_shape self.normalized_shape_per_partition = divide(normalized_shape, self.depth) @@ -58,6 +71,7 @@ class LayerNorm3D(ParallelLayer): else: self.bias = None self.variance_epsilon = eps + self.reset_parameters() self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self) -> None: @@ -67,8 +81,10 @@ class LayerNorm3D(ParallelLayer): def reset_parameters(self) -> None: init.ones_()(self.weight) + register_async_grad_hook(self.weight) if self.bias is not None: init.zeros_()(self.bias) + register_async_grad_hook(self.bias) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() @@ -134,8 +150,15 @@ class LayerNorm3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon, - self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) + return layernorm_3d( + input_, + self.weight, + self.bias, + self.normalized_shape, + self.variance_epsilon, + self.output_parallel_mode, + self.input_x_weight_parallel_mode, + ) @LAYERS.register_module @@ -161,6 +184,7 @@ class Linear3D(ParallelLayer): out_features: int, bias: bool = True, dtype: torch.dtype = None, + skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -168,10 +192,12 @@ class Linear3D(ParallelLayer): self.out_features = out_features self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) self.depth = get_depth_from_env() - self.in_features_per_partition = divide(in_features, self.depth) - self.out_features_per_partition = divide(out_features, self.depth**2) + self.skip_bias_add = skip_bias_add + self.in_features_per_partition = divide(in_features, self.depth**2) + self.out_features_per_partition = divide(out_features, self.depth) self.bias_features_per_partition = divide(out_features, self.depth) self.weight = Parameter( @@ -194,18 +220,23 @@ class Linear3D(ParallelLayer): if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode) + return grad + def reset_parameters(self, weight_initializer, bias_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.in_features, self.out_features weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + register_async_grad_hook(self.weight) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] - broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) - broadcast(self.bias, output_src_rank, self.output_parallel_mode) + broadcast(self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode) + self.bias.register_hook(self._sync_grad_hook) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() @@ -256,7 +287,7 @@ class Linear3D(ParallelLayer): local_state, self.weight_parallel_mode, dims={ - weight_key: -1, + weight_key: 0, bias_key: 0 }, partition_states={ @@ -279,7 +310,7 @@ class Linear3D(ParallelLayer): local_state, self.weight_parallel_mode, dims={ - weight_key: -1, + weight_key: 0, bias_key: 0 }, partition_states={ @@ -324,8 +355,20 @@ class Linear3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, - self.output_parallel_mode) + output = linear_3d( + input_, + self.weight, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias @LAYERS.register_module @@ -360,7 +403,7 @@ class Classifier3D(ParallelLayer): self.num_classes = num_classes self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) self.depth = get_depth_from_env() self.in_features_per_partition = divide(in_features, self.depth) @@ -386,19 +429,17 @@ class Classifier3D(ParallelLayer): def reset_parameters(self, weight_initializer, bias_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.in_features, self.num_classes - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] - input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0] if self.has_weight: weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) + broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode) + + register_async_grad_hook(self.weight) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) - broadcast(self.bias, output_src_rank, self.output_parallel_mode) - broadcast(self.bias, input_src_rank, self.input_parallel_mode) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR) + register_async_grad_hook(self.bias) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() @@ -468,8 +509,14 @@ class Classifier3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, - self.output_parallel_mode) + return classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) @LAYERS.register_module @@ -504,7 +551,8 @@ class VocabParallelClassifier3D(ParallelLayer): self.num_classes = num_classes self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) self.depth = get_depth_from_env() self.in_features_per_partition = divide(in_features, self.depth) self.out_features_per_partition = divide(num_classes, self.depth**2) @@ -544,12 +592,14 @@ class VocabParallelClassifier3D(ParallelLayer): if self.has_weight: weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + register_async_grad_hook(self.weight) + if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] - broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) - broadcast(self.bias, output_src_rank, self.output_parallel_mode) + broadcast(self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode) + register_async_grad_hook(self.bias) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() @@ -668,8 +718,14 @@ class VocabParallelClassifier3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode, - self.weight_parallel_mode, self.output_parallel_mode) + return vocab_parallel_classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) @LAYERS.register_module @@ -708,12 +764,16 @@ class PatchEmbedding3D(ParallelLayer): self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) - self.patch_size = to_2tuple(patch_size) - grid_size = to_2tuple(img_size // patch_size) - num_patches = grid_size[0] * grid_size[1] + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] self.embed_size = embed_size - embed_size_per_partition = divide(embed_size, self.depth) + embed_size_per_partition = embed_size // self.depth self.flatten = flatten self.weight = nn.Parameter( @@ -725,7 +785,7 @@ class PatchEmbedding3D(ParallelLayer): self.cls_token = nn.Parameter( torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) self.pos_embed = nn.Parameter( - torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self._set_tensor_parallel_attributes() @@ -737,8 +797,7 @@ class PatchEmbedding3D(ParallelLayer): set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) def _sync_grad_hook(self, grad) -> Tensor: - grad = all_reduce(grad.clone(), self.input_parallel_mode) - grad = all_reduce(grad, self.weight_parallel_mode) + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) return grad def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None: @@ -749,14 +808,10 @@ class PatchEmbedding3D(ParallelLayer): bias_initializer(self.bias, fan_in=fan_in) position_embed_initializer(self.pos_embed) - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0] - broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) - broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) - broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode) - broadcast(self.weight, input_src_rank, self.input_parallel_mode) - broadcast(self.bias, input_src_rank, self.input_parallel_mode) - broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode) + src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0] + broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode) self.weight.register_hook(self._sync_grad_hook) self.bias.register_hook(self._sync_grad_hook) @@ -850,8 +905,9 @@ class PatchEmbedding3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) - input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) + input_ = split_batch_3d(input_, + input_parallel_mode=self.input_parallel_mode, + weight_parallel_mode=self.weight_parallel_mode) output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: output = output.flatten(2).transpose(1, 2) # BCHW -> BNC @@ -906,7 +962,8 @@ class Embedding3D(ParallelLayer): self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -924,13 +981,18 @@ class Embedding3D(ParallelLayer): def _set_tensor_parallel_attributes(self) -> None: set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) + return grad + def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) + broadcast(self.weight, + gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode) + self.weight.register_hook(self._sync_grad_hook) def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: @@ -981,11 +1043,10 @@ class Embedding3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) - input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) - weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode, - self.output_parallel_mode) - output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + input_ = split_batch_3d(input_, + input_parallel_mode=self.input_parallel_mode, + weight_parallel_mode=self.weight_parallel_mode) + output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) return output @@ -1039,7 +1100,7 @@ class VocabParallelEmbedding3D(ParallelLayer): self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2) self.embed_dim_per_partition = divide(self.embed_dim, self.depth) vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode) diff --git a/colossalai/nn/layer/vanilla/__init__.py b/colossalai/nn/layer/vanilla/__init__.py index 40129b7ec..3d767b888 100644 --- a/colossalai/nn/layer/vanilla/__init__.py +++ b/colossalai/nn/layer/vanilla/__init__.py @@ -1,6 +1,14 @@ -from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm, VanillaPatchEmbedding, WrappedDropout, - WrappedDropPath) +from .layers import ( + DropPath, + VanillaClassifier, + VanillaLayerNorm, + VanillaLinear, + VanillaPatchEmbedding, + WrappedDropout, + WrappedDropPath, +) __all__ = [ - "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath" + "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath", + "VanillaLinear" ] diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/nn/layer/vanilla/layers.py index a90871236..225aed391 100644 --- a/colossalai/nn/layer/vanilla/layers.py +++ b/colossalai/nn/layer/vanilla/layers.py @@ -1,290 +1,341 @@ -import math -from typing import Callable - -import torch -import torch.nn.functional as F -from colossalai.context import seed -from colossalai.nn import init as init -from colossalai.registry import LAYERS -from colossalai.utils.cuda import get_current_device -from torch import Tensor -from torch import nn as nn - -from ..utils import to_2tuple - - -def drop_path(x, drop_prob: float = 0., training: bool = False): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. - - Args: - drop_prob (float, optional): probability of dropping path, defaults 0.0. - training (bool, optional): whether in training progress, defaults False. - """ - if drop_prob == 0. or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class DropPath(nn.Module): - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py - - Args: - drop_prob (float, optional): probability of dropping path, defaults None. - """ - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - -class WrappedDropout(nn.Module): - r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes - some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each - channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of - 1/(1-p) during training. This means that during evaluation the module simply computes an identity function. - - Args: - p (float, optional): probability of an element to be zeroed, defaults 0.5. - inplace (bool, optional): whether to do dropout in-place, default to be False. - mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): - super().__init__() - if p < 0 or p > 1: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - self.p = p - self.inplace = inplace - if mode is None: - self.func = self.nonefunc - else: - self.func = self.normalfunc - self.mode = mode - - def nonefunc(self, inputs): - return F.dropout(inputs, self.p, self.training, self.inplace) - - def normalfunc(self, inputs): - with seed(self.mode): - return F.dropout(inputs, self.p, self.training, self.inplace) - - def forward(self, inputs): - return self.func(inputs) - - -class WrappedDropPath(nn.Module): - r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Here, it is wrapped with the context of seed manager. - - Args: - p (float, optional): probability of dropping path, defaults 0.0. - mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - def __init__(self, p: float = 0., mode=None): - super().__init__() - self.p = p - self.mode = mode - if self.mode is None: - self.func = self.nonefunc - else: - self.func = self.normalfunc - self.mode = mode - - def nonefunc(self, inputs): - return drop_path(inputs, self.p, self.training) - - def normalfunc(self, inputs): - with seed(self.mode): - return drop_path(inputs, self.p, self.training) - - def forward(self, inputs): - return self.func(inputs) - - -@LAYERS.register_module -class VanillaPatchEmbedding(nn.Module): - r""" - 2D Image to Patch Embedding - - Args: - img_size (int): image size. - patch_size (int): patch size. - in_chans (int): number of channels of input image. - embed_size (int): size of embedding. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - flatten (bool, optional): whether to flatten output tensor, defaults to True. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - position_embed_initializer (:class:`typing.Callable`, optional): - The initializer of position embedding, defaults to zeros initializer. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - - self.weight = nn.Parameter( - torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) - self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) - self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) - self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) - - def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): - fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight) - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - bias_initializer(self.bias, fan_in=fan_in) - position_embed_initializer(self.pos_embed) - - def forward(self, input_: Tensor) -> Tensor: - B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) - if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC - - cls_token = self.cls_token.expand(output.shape[0], -1, -1) - output = torch.cat((cls_token, output), dim=1) - output = output + self.pos_embed - return output - - -@LAYERS.register_module -class VanillaClassifier(nn.Module): - r"""Dense linear classifier. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - flatten (bool, optional): whether to flatten output tensor, defaults to True. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = nn.Parameter( - torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)) - self.has_weight = True - if bias: - self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) - else: - self.bias = None - - self.reset_parameters(weight_initializer, bias_initializer) - - def reset_parameters(self, weight_initializer, bias_initializer): - fan_in, fan_out = self.in_features, self.num_classes - - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def forward(self, input_: Tensor) -> Tensor: - return F.linear(input_, self.weight, self.bias) - - -@LAYERS.register_module -class VanillaLayerNorm(nn.Module): - r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - super().__init__() - - self.normalized_shape = (normalized_shape,) - self.variance_epsilon = eps - - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - - self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) - if bias: - self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs)) - else: - self.bias = None - - def forward(self, x: Tensor) -> Tensor: - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon) +import math +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn as nn +from torch.nn.parameter import Parameter + +from colossalai.context import seed +from colossalai.nn import init as init +from colossalai.registry import LAYERS +from colossalai.utils.cuda import get_current_device + +from ..utils import to_2tuple + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + Args: + drop_prob (float, optional): probability of dropping path, defaults 0.0. + training (bool, optional): whether in training progress, defaults False. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + + Args: + drop_prob (float, optional): probability of dropping path, defaults None. + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class WrappedDropout(nn.Module): + r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes + some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each + channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of + 1/(1-p) during training. This means that during evaluation the module simply computes an identity function. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): + super().__init__() + if p < 0 or p > 1: + raise ValueError("dropout probability has to be between 0 and 1, " + "but got {}".format(p)) + self.p = p + self.inplace = inplace + if mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def normalfunc(self, inputs): + with seed(self.mode): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def forward(self, inputs): + return self.func(inputs) + + +class WrappedDropPath(nn.Module): + r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Here, it is wrapped with the context of seed manager. + + Args: + p (float, optional): probability of dropping path, defaults 0.0. + mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + def __init__(self, p: float = 0., mode=None): + super().__init__() + self.p = p + self.mode = mode + if self.mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return drop_path(inputs, self.p, self.training) + + def normalfunc(self, inputs): + with seed(self.mode): + return drop_path(inputs, self.p, self.training) + + def forward(self, inputs): + return self.func(inputs) + + +@LAYERS.register_module +class VanillaPatchEmbedding(nn.Module): + r""" + 2D Image to Patch Embedding + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.weight = nn.Parameter( + torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) + self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) + self.pos_embed = nn.Parameter( + torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight) + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def forward(self, input_: Tensor) -> Tensor: + B, C, H, W = input_.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = self.cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + self.pos_embed + return output + + +@LAYERS.register_module +class VanillaClassifier(nn.Module): + r"""Dense linear classifier. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = nn.Parameter( + torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)) + self.has_weight = True + if bias: + self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tensor: + return F.linear(input_, self.weight, self.bias) + + +@LAYERS.register_module +class VanillaLayerNorm(nn.Module): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + super().__init__() + + self.normalized_shape = (normalized_shape,) + self.variance_epsilon = eps + + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + + self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) + if bias: + self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs)) + else: + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon) + + +@LAYERS.register_module +class VanillaLinear(nn.Module): + """Linear layer. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add: bool (optional, default to be false). + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + weight_initializer(self.weight, fan_in=in_features, fan_out=out_features) + if self.bias is not None: + bias_initializer(self.bias, fan_in=in_features) + + def forward(self, input: Tensor) -> Tensor: + if not self.skip_bias_add: + return F.linear(input, self.weight, self.bias) + else: + return F.linear(input, self.weight), self.bias diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index ea08ff723..54036973e 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -1,9 +1,12 @@ import math +from typing import Optional + import torch +from colossalai.kernel.op_builder import CPUAdamBuilder from colossalai.registry import OPTIMIZERS + from .nvme_optimizer import NVMeOptimizer -from typing import Optional @OPTIMIZERS.register_module @@ -11,12 +14,12 @@ class CPUAdam(NVMeOptimizer): """Implements Adam algorithm. Supports parameters updating on both GPU and CPU, depanding on the device of paramters. - But the parameters and gradients should on the same device: + But the parameters and gradients should on the same device: * Parameters on CPU and gradients on CPU is allowed. * Parameters on GPU and gradients on GPU is allowed. * Parameters on GPU and gradients on CPU is **not** allowed. - Requires ColossalAI to be installed via ``pip install .``. + `CPUAdam` requires CUDA extensions which can be built during installation or runtime. This version of CPU Adam accelates parameters updating on CPU with SIMD. Support of AVX2 or AVX512 is required. @@ -44,7 +47,7 @@ class CPUAdam(NVMeOptimizer): (default: False) NOT SUPPORTED yet in CPUAdam! adamw_mode (boolean, optional): Apply L2 regularization or weight decay True for decoupled weight decay(also known as AdamW) (default: True) - simd_log (boolean, optional): whether to show if you are using SIMD to + simd_log (boolean, optional): whether to show if you are using SIMD to accelerate. (default: False) nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. @@ -74,10 +77,7 @@ class CPUAdam(NVMeOptimizer): default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode - try: - import cpu_adam - except ImportError: - raise ImportError('Please install colossalai from source code to use CPUAdam') + cpu_adam = CPUAdamBuilder().load() self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) def torch_adam_update(self, @@ -114,7 +114,7 @@ class CPUAdam(NVMeOptimizer): data.addcdiv_(exp_avg, denom, value=-step_size) @torch.no_grad() - def step(self, closure=None): + def step(self, closure=None, div_scale: float = -1): loss = None if closure is not None: with torch.enable_grad(): @@ -149,9 +149,10 @@ class CPUAdam(NVMeOptimizer): self._pre_update(p, 'exp_avg', 'exp_avg_sq') self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'], - state['exp_avg_sq'], -1) + state['exp_avg_sq'], div_scale) self._post_update(p, 'exp_avg', 'exp_avg_sq') elif target_device.type == 'cuda': + assert div_scale == -1, "div_scale should remain default" assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 5814c28bd..941866d55 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier class FusedAdam(torch.optim.Optimizer): """Implements Adam algorithm. - Currently GPU-only. Requires ColossalAI to be installed via - ``pip install .``. + `FusedAdam` requires CUDA extensions which can be built during installation or runtime. This version of fused Adam implements 2 fusions. @@ -20,7 +19,7 @@ class FusedAdam(torch.optim.Optimizer): :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, or ``torch.optim.Adam`` with ``adamw_mode=False`` - :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp. + :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp. Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. @@ -65,10 +64,12 @@ class FusedAdam(torch.optim.Optimizer): self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: - import colossal_C + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - self.multi_tensor_adam = colossal_C.multi_tensor_adam + self.multi_tensor_adam = fused_optim.multi_tensor_adam else: raise RuntimeError('FusedAdam requires cuda extensions') @@ -80,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer): else: super(FusedAdam, self).zero_grad() - def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None): + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1): """Performs a single optimization step. Arguments: @@ -136,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer): multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction, - group['weight_decay']) + group['weight_decay'], div_scale) return loss diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index be12e6c62..72520064e 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier class FusedLAMB(torch.optim.Optimizer): """Implements LAMB algorithm. - Currently GPU-only. Requires ColossalAI to be installed via - ``pip install .``. + `FusedLAMB` requires CUDA extensions which can be built during installation or runtime. This version of fused LAMB implements 2 fusions. @@ -76,13 +75,15 @@ class FusedLAMB(torch.optim.Optimizer): max_grad_norm=max_grad_norm) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - import colossal_C - self.multi_tensor_l2norm = colossal_C.multi_tensor_l2norm + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + + self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) - self.multi_tensor_lamb = colossal_C.multi_tensor_lamb + self.multi_tensor_lamb = fused_optim.multi_tensor_lamb else: raise RuntimeError('FusedLAMB requires cuda extensions') diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 1185eef81..468713b22 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -10,8 +10,7 @@ from colossalai.utils import multi_tensor_applier class FusedSGD(Optimizer): r"""Implements stochastic gradient descent (optionally with momentum). - Currently GPU-only. Requires ColossalAI to be installed via - ``pip install .``. + `FusedSGD` requires CUDA extensions which can be built during installation or runtime. This version of fused SGD implements 2 fusions. @@ -20,7 +19,7 @@ class FusedSGD(Optimizer): :class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD`` - :class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp. + :class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp. Nesterov momentum is based on the formula from `On the importance of initialization and momentum in deep learning`__. @@ -80,12 +79,14 @@ class FusedSGD(Optimizer): self.wd_after_momentum = wd_after_momentum if multi_tensor_applier.available: - import colossal_C + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) - self.multi_tensor_sgd = colossal_C.multi_tensor_sgd + self.multi_tensor_sgd = fused_optim.multi_tensor_sgd else: raise RuntimeError('FusedSGD requires cuda extensions') diff --git a/colossalai/nn/optimizer/gemini_optimizer.py b/colossalai/nn/optimizer/gemini_optimizer.py new file mode 100644 index 000000000..31d161612 --- /dev/null +++ b/colossalai/nn/optimizer/gemini_optimizer.py @@ -0,0 +1,15 @@ +from typing import Any + +import torch + +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer + +__all__ = ['GeminiAdamOptimizer'] + + +class GeminiAdamOptimizer(ZeroOptimizer): + + def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: + optimizer = HybridAdam(model.parameters(), **defaults) + super().__init__(optimizer, model, **defaults) diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 761843aab..1d0fb92de 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -1,8 +1,11 @@ +from typing import Any, Optional + import torch -from colossalai.utils import multi_tensor_applier +from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder from colossalai.registry import OPTIMIZERS -from typing import Optional +from colossalai.utils import multi_tensor_applier + from .nvme_optimizer import NVMeOptimizer @@ -11,12 +14,12 @@ class HybridAdam(NVMeOptimizer): """Implements Adam algorithm. Supports parameters updating on both GPU and CPU, depanding on the device of paramters. - But the parameters and gradients should on the same device: + But the parameters and gradients should on the same device: * Parameters on CPU and gradients on CPU is allowed. * Parameters on GPU and gradients on GPU is allowed. * Parameters on GPU and gradients on CPU is **not** allowed. - Requires ColossalAI to be installed via ``pip install .`` + `HybriadAdam` requires CUDA extensions which can be built during installation or runtime. This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam. @@ -43,7 +46,7 @@ class HybridAdam(NVMeOptimizer): (default: False) NOT SUPPORTED yet in CPUAdam! adamw_mode (boolean, optional): Apply L2 regularization or weight decay True for decoupled weight decay(also known as AdamW) (default: True) - simd_log (boolean, optional): whether to show if you are using SIMD to + simd_log (boolean, optional): whether to show if you are using SIMD to accelerate. (default: False) nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. @@ -68,24 +71,23 @@ class HybridAdam(NVMeOptimizer): weight_decay=0, adamw_mode=True, nvme_offload_fraction: float = 0.0, - nvme_offload_dir: Optional[str] = None): + nvme_offload_dir: Optional[str] = None, + **defaults: Any): default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode - try: - import cpu_adam - import colossal_C - except ImportError: - raise ImportError('Please install colossalai from source code to use HybridAdam') - self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) + # build during runtime if not found + cpu_optim = CPUAdamBuilder().load() + fused_optim = FusedOptimBuilder().load() + self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) - self.gpu_adam_op = colossal_C.multi_tensor_adam + self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) @torch.no_grad() - def step(self, closure=None): + def step(self, closure=None, div_scale: float = -1): loss = None if closure is not None: with torch.enable_grad(): @@ -122,7 +124,7 @@ class HybridAdam(NVMeOptimizer): self._pre_update(p, 'exp_avg', 'exp_avg_sq') self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'], - state['exp_avg_sq'], -1) + state['exp_avg_sq'], div_scale) self._post_update(p, 'exp_avg', 'exp_avg_sq') elif target_device.type == 'cuda': @@ -142,6 +144,6 @@ class HybridAdam(NVMeOptimizer): bias_correction = 1 if group['bias_correction'] else 0 multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode, - bias_correction, group['weight_decay']) + bias_correction, group['weight_decay'], div_scale) self._post_step() return loss diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py similarity index 78% rename from colossalai/zero/zero_optimizer.py rename to colossalai/nn/optimizer/zero_optimizer.py index aee8b2799..2786d4496 100644 --- a/colossalai/zero/zero_optimizer.py +++ b/colossalai/nn/optimizer/zero_optimizer.py @@ -1,15 +1,20 @@ +import math +from enum import Enum +from typing import Any, Dict, Set, Tuple + import torch import torch.distributed as dist -from enum import Enum -from torch.optim import Optimizer from torch.nn import Parameter -from colossalai.nn.parallel.data_parallel import ZeroDDP -from typing import Dict, Tuple, Set +from torch.optim import Optimizer + from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.utils import get_current_device, disposable from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam +from colossalai.nn.parallel.data_parallel import ZeroDDP +from colossalai.utils import disposable, get_current_device + +_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} class OptimState(Enum): @@ -53,9 +58,13 @@ class ZeroOptimizer(ColossalaiOptimizer): backoff_factor: float = 0.5, growth_interval: int = 1000, hysteresis: int = 2, - max_scale: float = 2**32): + max_scale: float = 2**32, + clipping_norm: float = 0.0, + norm_type: float = 2.0, + **defaults: Any): super().__init__(optim) assert isinstance(module, ZeroDDP) + assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list" self.module = module self.gemini_manager = module.gemini_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager @@ -63,11 +72,17 @@ class ZeroOptimizer(ColossalaiOptimizer): self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() self.param_to_chunk32: Dict[Parameter, Chunk] = dict() self.chunk16_set: Set[Chunk] = set() + self.clipping_flag = clipping_norm > 0.0 + self.max_norm = clipping_norm + + if self.clipping_flag: + assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)] for p, fp32_p in zip(params_list, module.fp32_params): chunk_16 = self.chunk_manager.get_chunk(p) if chunk_16 not in self.chunk16_set: + chunk_16.l2_norm_flag = self.clipping_flag self.chunk16_set.add(chunk_16) self.__init__optimizer() @@ -125,13 +140,49 @@ class ZeroOptimizer(ColossalaiOptimizer): return self._found_overflow.item() > 0 - def _unscale_grads(self): - assert self.optim_state == OptimState.SCALED - for group in self.optim.param_groups: - for p in group['params']: - if p.grad is not None: - p.grad.data.div_(self.loss_scale) - self.optim_state = OptimState.UNSCALED + def _calc_global_norm(self) -> float: + norm_sqr: float = 0.0 + group_to_norm = dict() + for c16 in self.chunk16_set: + assert c16.l2_norm is not None + + if c16.is_gathered: + norm_sqr += c16.l2_norm + else: + # this chunk is sharded, use communication to collect total norm + if c16.torch_pg not in group_to_norm: + group_to_norm[c16.torch_pg] = 0.0 + group_to_norm[c16.torch_pg] += c16.l2_norm + + c16.l2_norm = None # clear l2 norm + + comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) + for group, part_norm in group_to_norm.items(): + comm_buffer.fill_(part_norm) + dist.all_reduce(comm_buffer, group=group) + norm_sqr += comm_buffer.item() + + global_norm = math.sqrt(norm_sqr) + return global_norm + + def _get_combined_scale(self): + loss_scale = 1 + + if self.optim_state == OptimState.SCALED: + loss_scale = self.loss_scale + self.optim_state = OptimState.UNSCALED + + combined_scale = loss_scale + if self.clipping_flag: + total_norm = self._calc_global_norm() + clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm + if clip > 1: + combined_scale = clip * loss_scale + + if combined_scale == 1: + return -1 + else: + return combined_scale @property def loss_scale(self): @@ -144,17 +195,22 @@ class ZeroOptimizer(ColossalaiOptimizer): def step(self, *args, **kwargs): self._maybe_move_fp32_params() self._set_grad_ptr() - # unscale grads if scaled - if self.optim_state == OptimState.SCALED: - self._unscale_grads() + found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) if found_inf: + self.optim_state = OptimState.UNSCALED # no need to unscale grad + self.grad_scaler.update(found_inf) # update gradient scaler self._logger.info(f'Found overflow. Skip step') - self.zero_grad() + self.zero_grad() # reset all gradients self._update_fp16_params() return - ret = self.optim.step(*args, **kwargs) + + # get combined scale. combined scale = loss scale * clipping norm + # so that gradient = gradient / combined scale + combined_scale = self._get_combined_scale() + self.grad_scaler.update(found_inf) + + ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) self._register_states() self.zero_grad() self._update_fp16_params() @@ -219,6 +275,8 @@ class ZeroOptimizer(ColossalaiOptimizer): def get_range_pair(local_chunk: Chunk, local_param: Parameter): param_info = local_chunk.tensors_info[local_param] + if local_chunk.keep_gathered: + return param_info.offset, param_info.end begin = max(0, param_info.offset - local_chunk.shard_begin) end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) return begin, end diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/nn/parallel/__init__.py index 9645e95f6..0c369bfce 100644 --- a/colossalai/nn/parallel/__init__.py +++ b/colossalai/nn/parallel/__init__.py @@ -1,3 +1,4 @@ from .data_parallel import ColoDDP, ZeroDDP +from .gemini_parallel import GeminiDDP -__all__ = ['ColoDDP', 'ZeroDDP'] +__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP'] diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 5bce81708..a7d79be16 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -1,19 +1,24 @@ -import torch import itertools -import torch.distributed as dist -from functools import partial -from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2 -from colossalai.tensor.param_op_hook import ParamOpHookManager -from colossalai.gemini.gemini_mgr import GeminiManager -from typing import Dict, Iterable, List, Optional, Set -from colossalai.logging import get_dist_logger from collections import OrderedDict -from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec -from colossalai.tensor import ProcessGroup as ColoProcessGroup -from .reducer import Reducer +from functools import partial +from typing import Dict, Iterable, List, Optional, Set -from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager +import torch +import torch.distributed as dist + +from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.gemini.memory_tracer import OrderedParamGenerator +from colossalai.logging import get_dist_logger from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import get_current_device +from colossalai.zero.utils.gemini_hook import GeminiZeROHook + +from .reducer import Reducer +from .utils import get_static_torch_model try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys @@ -185,25 +190,16 @@ class ColoDDP(torch.nn.Module): class ZeroDDP(ColoDDP): - """ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now. - We can configure chunk and gemini via ChunkManager and GeminiManager respectively. + """ZeRO DDP for ColoTensor. + Warning: Nested ZeroDDP is not supported now. + It is designed to be used with ChunkManager and GeminiManager. For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. - Example: - >>> model = torch.nn.Linear(20, 1) - >>> placement_policy = 'cuda' - >>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None - >>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy)) - >>> gemini_manager = GeminiManager(placement_policy, chunk_manager) - >>> model = ZeroDDP(model, gemini_manager) - >>> logits = model(x) - >>> loss = criterion(logits, labels) - >>> model.backward(loss) - Args: module (torch.nn.Module): Module to apply ZeRO-DP. gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space. For more details, see the API reference of ``GeminiManager``. + pin_memory (bool): Chunks on CPU Memory use pin-memory. force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False. """ @@ -216,13 +212,24 @@ class ZeroDDP(ColoDDP): self.gemini_manager = gemini_manager self.chunk_manager: ChunkManager = gemini_manager.chunk_manager self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = ZeROHookV2(gemini_manager) + self.param_op_hook = GeminiZeROHook(gemini_manager) self.fp32_params: List[ColoTensor] = [] self.overflow_counter = 0 self.grads_device: Dict[torch.Tensor, torch.device] = {} - # TODO: get param order and filter unused params - for p in module.parameters(): + cpu_offload = self.gemini_manager.policy_name != 'cuda' + + if self.gemini_manager._premade_memstats_: + # build chunk in param runtime visited order. + param_order = self.gemini_manager.memstats()._param_runtime_order + else: + # build chunk in param initialized order. + # Note: in this way, it can not get filter unused params during runtime. + param_order = OrderedParamGenerator() + for p in module.parameters(): + param_order.append(p) + + for p in param_order.generate(): assert isinstance(p, ColoParameter) if getattr(p, '_ddp_to_ignore', False): @@ -232,28 +239,40 @@ class ZeroDDP(ColoDDP): fp32_data = p.data.float() fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) p.data = p.data.half() - dp_world_size = p.process_group.dp_world_size() - self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory) - self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory) + self.chunk_manager.register_tensor(tensor=p, + group_type='fp16_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + self.chunk_manager.register_tensor(tensor=fp32_p, + group_type='fp32_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) self.fp32_params.append(fp32_p) self.grads_device[p] = self.gemini_manager.default_device + self.chunk_manager.close_all_groups() self._cast_buffers() - params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)] + params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)] for p, fp32_p in zip(params_list, self.fp32_params): chunk_16 = self.chunk_manager.get_chunk(p) chunk_32 = self.chunk_manager.get_chunk(fp32_p) chunk_32.init_pair(chunk_16) + # keep gathered chunks are in CUDA + if chunk_16.keep_gathered: + self.grads_device[p] = get_current_device() + self._logger = get_dist_logger() def forward(self, *args, **kwargs): args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) self.module.zero_grad(set_to_none=True) - self.gemini_manager.pre_iter() - with ParamOpHookManager.use_hooks(self.param_op_hook): + self.gemini_manager.pre_iter(*args) + with ColoParamOpHookManager.use_hooks(self.param_op_hook): outputs = self.module(*args, **kwargs) if self.force_outputs_fp32: return _cast_float(outputs, torch.float) @@ -266,7 +285,9 @@ class ZeroDDP(ColoDDP): p.grad = None def _post_backward(self): - assert self.chunk_manager.accessed_mem == 0 + if self.chunk_manager.accessed_mem != 0: + raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", + "The most possible reason is that the model is not compatible with ZeroDDP.") self._setup_grads_ptr() self._logger.debug( f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' @@ -274,12 +295,12 @@ class ZeroDDP(ColoDDP): self.gemini_manager.post_iter() def backward(self, loss: torch.Tensor): - with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): loss.backward() self._post_backward() def backward_by_grad(self, tensor, grad): - with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): torch.autograd.backward(tensor, grad) self._post_backward() @@ -287,16 +308,21 @@ class ZeroDDP(ColoDDP): empty_grad = torch.empty_like(grad) free_storage(empty_grad) with torch._C.DisableTorchFunction(): - self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) chunk = self.chunk_manager.get_chunk(p) + assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD + self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) chunk.copy_tensor_to_chunk_slice(p, grad) reduced = self.chunk_manager.reduce_chunk(chunk) if reduced: if chunk.is_gathered: - chunk.chunk_total.div_(chunk.pg_size) + chunk.cuda_global_chunk.div_(chunk.pg_size) else: chunk.cuda_shard.div_(chunk.pg_size) + # check overflow elements self.overflow_counter += chunk.has_inf_or_nan + # record l2 norm for gradient clipping + if chunk.l2_norm_flag: + chunk.set_l2_norm() self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) return empty_grad @@ -307,12 +333,10 @@ class ZeroDDP(ColoDDP): for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): - r"""Returns a dictionary containing a whole state of the module. - - Both parameters and persistent buffers (e.g. running averages) are - included. Keys are corresponding parameter and buffer names. - Parameters and buffers set to ``None`` are not included. + def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True): + """ + Args: + strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()` Returns: dict: @@ -322,7 +346,30 @@ class ZeroDDP(ColoDDP): >>> module.state_dict().keys() ['bias', 'weight'] + """ + if strict: + assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." + torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0) + return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + return self._non_strict_state_dict(destination=destination, + prefix=prefix, + keep_vars=keep_vars, + only_rank_0=only_rank_0) + def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + Warning: The non strict state dict would ignore the parameters if the tensors of the parameters + are shared with other parameters which have been included in the dictionary. + When you need to load the state dict, you should set the argument `strict` to False. + + Returns: + dict: + a dictionary containing a whole state of the module """ if destination is None: destination = OrderedDict() @@ -336,6 +383,35 @@ class ZeroDDP(ColoDDP): destination = hook_result return destination + def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict: + """ + get param content from chunks. + + Args: + param_list (_type_): a list of torch.nn.Parameters + only_rank_0 (_type_): _description_ + + Returns: + Dict: a dict whose key is param name and value is param with correct payload + """ + # save parameters + param_to_save_data = dict() + chunk_list = self.chunk_manager.get_chunks(param_list) + for chunk in chunk_list: + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + + for tensor, tensor_info in chunk.tensors_info.items(): + record_tensor = torch.empty([0]) + record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) + if record_flag: + record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + + assert tensor not in param_to_save_data + param_to_save_data[tensor] = record_tensor + + del temp_chunk + return param_to_save_data + def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): r"""Saves module state to `destination` dictionary, containing a state of the module, but not its descendants. This is called on every @@ -351,23 +427,8 @@ class ZeroDDP(ColoDDP): """ assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." - # save parameters - param_to_save_data = dict() - chunk_list = self.chunk_manager.get_chunks(self.fp32_params) - for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) - - for tensor, tensor_info in chunk.tensors_info.items(): - record_tensor = torch.empty([0]) - record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) - if record_flag: - record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() - - assert tensor not in param_to_save_data - param_to_save_data[tensor] = record_tensor - - del temp_chunk - + param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0) + # TODO: (HELSON) deal with ddp ignored parameters for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): if p is not None: assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) @@ -519,7 +580,7 @@ class ZeroDDP(ColoDDP): load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) if chunk.is_gathered: - chunk.chunk_total.copy_(temp_chunk) + chunk.cuda_global_chunk.copy_(temp_chunk) elif chunk.cuda_shard is not None: chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) else: diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py new file mode 100644 index 000000000..cd5ef424a --- /dev/null +++ b/colossalai/nn/parallel/gemini_parallel.py @@ -0,0 +1,57 @@ +from typing import Optional + +import torch + +from colossalai.gemini.chunk import init_chunk_manager +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.gemini.memory_tracer import MemStats + +from .data_parallel import ZeroDDP + + +class GeminiDDP(ZeroDDP): + + def __init__(self, + module: torch.nn.Module, + device: torch.device, + placement_policy: str = "cpu", + pin_memory: bool = False, + force_outputs_fp32: bool = False, + search_range_mb: int = 32, + hidden_dim: Optional[int] = None, + min_chunk_size_mb: Optional[float] = None, + memstats: Optional[MemStats] = None) -> None: + """ + A torch.Module warpper using ZeRO-DP and Genimi. + ZeRO is for parallel. Gemini is for memory management. + WARNING: The class will modify the module inline! + + Example: + model is initialized under the context of ColoInitContext + >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda") + >>> logits = model(x) + >>> loss = criterion(logits, labels) + >>> model.backward(loss) + + Args: + module (torch.nn.Module): the model to be wrapped. + device (torch.device): device to place the model. + placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". + pin_memory (bool, optional): use pin memory on CPU. Defaults to False. + force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. + search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. + hidden_dim (int, optional): the hidden dimension of DNN. + Users can provide this argument to speed up searching. + If users do not know this argument before training, it is ok. We will use a default value 1024. + min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. + If the aggregate size of parameters is still samller than the minimum chunk size, + all parameters will be compacted into one small chunk. + memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. + """ + chunk_manager = init_chunk_manager(model=module, + init_device=device, + hidden_dim=hidden_dim, + search_range_mb=search_range_mb, + min_chunk_size_mb=min_chunk_size_mb) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32) diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py index 587339549..d323556d5 100644 --- a/colossalai/nn/parallel/utils.py +++ b/colossalai/nn/parallel/utils.py @@ -1,20 +1,112 @@ -import torch -import torch.distributed as dist -from colossalai.gemini.chunk import Chunk -from colossalai.utils import get_current_device - - -def get_temp_total_chunk_on_cuda(chunk: Chunk): - if chunk.is_gathered: - return chunk.chunk_total - - if chunk.cuda_shard is not None: - shard_temp = chunk.cuda_shard - else: - shard_temp = chunk.cpu_shard.to(get_current_device()) - - total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device()) - gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) - dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) - - return total_temp +from collections import OrderedDict +from copy import copy +from typing import Optional, Set + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.gemini.chunk import Chunk +from colossalai.utils import get_current_device + + +def get_temp_total_chunk_on_cuda(chunk: Chunk): + if chunk.is_gathered: + return chunk.cuda_global_chunk + + if chunk.cuda_shard is not None: + shard_temp = chunk.cuda_shard + else: + shard_temp = chunk.cpu_shard.to(get_current_device()) + + total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device()) + gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) + dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) + + return total_temp + + +def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''): + """Get a dfs module list of the given module. Its order is same as the order of creations of modules. + """ + if memo is None: + memo = set() + if module not in memo: + for name, submodule in module._modules.items(): + if submodule is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + for m in _get_dfs_module_list(submodule, memo, submodule_prefix): + yield m + + memo.add(module) + yield prefix, module + + +def _get_shallow_copy_model(model: nn.Module): + """Get a shallow copy of the given model. Each submodule is different from the original submodule. + But the new submodule and the old submodule share all attributes. + """ + old_to_new = dict() + for name, module in _get_dfs_module_list(model): + new_module = copy(module) + new_module._modules = OrderedDict() + for subname, submodule in module._modules.items(): + if submodule is None: + continue + setattr(new_module, subname, old_to_new[submodule]) + old_to_new[module] = new_module + return old_to_new[model] + + +def get_static_torch_model(zero_ddp_model, + device=torch.device("cpu"), + dtype=torch.float32, + only_rank_0=True) -> torch.nn.Module: + """Get a static torch.nn.Module model from the given ZeroDDP module. + You should notice that the original ZeroDDP model is not modified. + Thus, you can use the original model in further training. + But you should not use the returned torch model to train, this can cause unexpected errors. + + Args: + zero_ddp_model (ZeroDDP): a zero ddp model + device (torch.device): the device of the final torch model + dtype (torch.dtype): the dtype of the final torch model + only_rank_0 (bool): if True, only rank0 has the coverted torch model + + Returns: + torch.nn.Module: a static torch model used for saving checkpoints or numeric checks + """ + from colossalai.nn.parallel import ZeroDDP + assert isinstance(zero_ddp_model, ZeroDDP) + + state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False) + colo_model = zero_ddp_model.module + torch_model = _get_shallow_copy_model(colo_model) + + if not only_rank_0 or dist.get_rank() == 0: + # record the mapping relationship between colo parameters and torch parameters + colo_to_torch = dict() + for (name, colo_module), (_, torch_module) in \ + zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)): + # clean the parameter list of the new torch module + torch_module._parameters = OrderedDict() + for sufix_param_name, param in colo_module.named_parameters(recurse=False): + # get the full name of the parameter + full_param_name = name + ('.' if name else '') + sufix_param_name + + if full_param_name not in state_dict: + # this means the parameter is shared by multiple modules + # we should use colo_to_torch to get the torch parameter created before + assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module" + torch_param = colo_to_torch[param] + else: + # we meet the parameter the first time, just use the state dict to get the data + state_param = state_dict[full_param_name] + torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype)) + colo_to_torch[param] = torch_param + + setattr(torch_module, sufix_param_name, torch_param) + dist.barrier() + + return torch_model diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py index 625bd7ef5..0fcde9707 100644 --- a/colossalai/pipeline/__init__.py +++ b/colossalai/pipeline/__init__.py @@ -1,4 +1,4 @@ from .pipelinable import PipelinableContext, PipelinableModel -from .layer_sepc import LayerSpec +from .layer_spec import LayerSpec __all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] \ No newline at end of file diff --git a/colossalai/pipeline/layer_sepc.py b/colossalai/pipeline/layer_spec.py similarity index 100% rename from colossalai/pipeline/layer_sepc.py rename to colossalai/pipeline/layer_spec.py diff --git a/colossalai/pipeline/middleware/__init__.py b/colossalai/pipeline/middleware/__init__.py new file mode 100644 index 000000000..79e19f9ea --- /dev/null +++ b/colossalai/pipeline/middleware/__init__.py @@ -0,0 +1,3 @@ +from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal + +__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] \ No newline at end of file diff --git a/colossalai/pipeline/middleware/adaptor/__init__.py b/colossalai/pipeline/middleware/adaptor/__init__.py new file mode 100644 index 000000000..949700a2c --- /dev/null +++ b/colossalai/pipeline/middleware/adaptor/__init__.py @@ -0,0 +1,3 @@ +from .fx import get_topology as get_fx_topology + +__all__ = ['get_fx_topology'] \ No newline at end of file diff --git a/colossalai/pipeline/middleware/adaptor/fx.py b/colossalai/pipeline/middleware/adaptor/fx.py new file mode 100644 index 000000000..8437c5194 --- /dev/null +++ b/colossalai/pipeline/middleware/adaptor/fx.py @@ -0,0 +1,145 @@ +from torch.fx.graph_module import GraphModule +from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo +import torch + +def partition_name_to_id(partition_name, is_input=False, is_output=False): + if is_input: + partition_id = 0 + elif is_output: + partition_id = 1 + else: + prefix = 'submod_' + partition_id = int(partition_name.split(prefix)[-1]) + 2 + return partition_id + +# There are two kinds of def in fx.graph +# 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value. +# e.g. submod1 = call_module(...) +# temporary_val = submod1[0] +# submod2 = call_module(temporary_val, ...) +# 2. direct_use & direct_def, which means the output is used by next partition directly. +# e.g. submod1 = call_module(...) +# submod2 = call_module(submod1, ...) +def find_input_in_partition(node, partitions, input_partitions=None): + p_input_val = None + direct_def = not node.name.startswith('getitem') + # search in input + if direct_def and input_partitions is not None: + partition_id = partition_name_to_id('', is_input=True) + for i, input_node in enumerate(input_partitions): + if input_node == node: + p_input_val = PartitionInputVal(partition_id=partition_id, offset=i) + return p_input_val + # search submod in mid part + if direct_def: + for partition in partitions: + if partition == node: + partition_id = partition_name_to_id(partition.name) + p_input_val = PartitionInputVal(partition_id=partition_id, offset=0) + return p_input_val + # search temporary value in graph + else: + for partition in partitions: + for offset, mid_val in enumerate(partition.users): + if mid_val == node: + partition_id = partition_name_to_id(partition.name) + p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset) + return p_input_val + + return p_input_val + +def find_output_in_partition(node, partitions, output_partitions=None): + p_output_val = PartitionOutputVal() + for user in node.users: + direct_use = not user.name.startswith('getitem') + # user is mid partition + for partition in partitions: + # direct call + if direct_use: + if user == partition: + partition_id = partition_name_to_id(partition.name) + for i, arg in enumerate(partition.args): + if arg == node: + p_output_val.add(partition_id=partition_id, offset=i) + break + # getitem call + else: + if user in partition.args: + partition_id = partition_name_to_id(partition.name) + for i, arg in enumerate(partition.args): + if arg == user: + p_output_val.add(partition_id=partition_id, offset=i) + break + + # user is output + if output_partitions is not None: + output_node = output_partitions[0] + if user.op == output_node.op: + output_keys = {} + partition_id = partition_name_to_id('', is_output=True) + torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n)) + for i, arg in enumerate(output_keys): + if arg == node: + p_output_val.add(partition_id=partition_id, offset=i) + break + return p_output_val + +def get_topology(gm: GraphModule): + topo = Topo() + topo_output_partition = Partition() + + input_partitions = [] + partitions = [] + output_partitions = [] + for node in gm.graph.nodes: + if node.op == 'placeholder': + input_partitions.append(node) + elif node.name.startswith('submod_'): + partitions.append(node) + elif node.op == 'output': + output_partitions.append(node) + else: + continue + + # set output for input_partition + topo_input_partition = Partition() + for partition in input_partitions: + cur_node = partition + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_input_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=0, partition=topo_input_partition) + topo.set_input_partition_id(partition_id=0) + + for i, partition in enumerate(partitions): + topo_mid_partition = Partition() + # set input for submodule + for arg in partition.args: + cur_node = arg + p_input_val = find_input_in_partition(cur_node, partitions, input_partitions) + topo_mid_partition.add_input_val(p_input_val) + # set output for submodule + direct_use = True + for user in partition.users: + if user.name.startswith('getitem'): + direct_use = False + break + if direct_use: + cur_node = partition + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_mid_partition.add_output_val(p_output_val) + else: + for user in partition.users: + cur_node = user + p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) + topo_mid_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=i+2, partition=topo_mid_partition) + + # set input for output_partition + for partition in output_partitions: + topo_output_partition = Partition() + torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val( + find_input_in_partition(n, partitions, input_partitions))) + topo.set_partitions(partition_id=1, partition=topo_output_partition) + topo.set_output_partition_id(partition_id=1) + + return topo \ No newline at end of file diff --git a/colossalai/pipeline/middleware/topo.py b/colossalai/pipeline/middleware/topo.py new file mode 100644 index 000000000..e798e2ed9 --- /dev/null +++ b/colossalai/pipeline/middleware/topo.py @@ -0,0 +1,206 @@ +from typing import Dict, List +from dataclasses import dataclass + +# This file includes data structure used by Pipeline Middleware. + +@dataclass +class ValPosition: + partition_id: int + offset: int + + def __str__(self) -> str: + res = f'[partition_id:{self.partition_id},offset:{self.offset}]' + return res + + def __repr__(self) -> str: + return self.__str__() + +class PartitionInputVal(object): + def __init__(self, partition_id, offset) -> None: + # every input from which partition_id and which offset + val_pos = ValPosition(partition_id, offset) + self._from_partition_and_offset: ValPosition = val_pos + + def get(self): + return self._from_partition_and_offset + + def __str__(self) -> str: + res = '' + res += f'<-({self._from_partition_and_offset})' + return res + + def __repr__(self) -> str: + return self.__str__() + +class PartitionOutputVal(object): + def __init__(self) -> None: + # every output to which partition_id and which offset + self._to_partition_and_offset: List[ValPosition] = [] + + def add(self, partition_id, offset): + val_pos = ValPosition(partition_id, offset) + self._to_partition_and_offset.append(val_pos) + + def get(self): + return self._to_partition_and_offset + + def __str__(self) -> str: + res = '' + res += '->(' + for val_pos in self._to_partition_and_offset: + res += f'{val_pos},' + res += ')' + return res + + def __repr__(self) -> str: + return self.__str__() + +class Partition(object): + def __init__(self) -> None: + self._input_vals: List[PartitionInputVal] = [] + self._output_vals: List[PartitionOutputVal] = [] + + def add_input_val(self, input_val: PartitionInputVal): + self._input_vals.append(input_val) + + def add_output_val(self, output_val: PartitionOutputVal): + self._output_vals.append(output_val) + + def get_input_vals(self): + return self._input_vals + + def get_output_vals(self): + return self._output_vals + + # get the output offsets sent to dst_partition_id + def get_output_offsets(self, dst_partition_id): + res = [] + for offset, output_val in enumerate(self._output_vals): + outputs = output_val.get() + for val_pos in outputs: + if val_pos.partition_id == dst_partition_id: + res.append(offset) + + return res + + # get all input dst partition_ids + def get_input_partition_ids(self): + res = [] + for input_val in self._input_vals: + val_pos = input_val.get() + if val_pos.partition_id not in res: + res.append(val_pos.partition_id) + return res + + # get all output dst partition_ids + def get_output_partition_ids(self): + res = [] + for output_val in self._output_vals: + outputs = output_val.get() + for val_pos in outputs: + if val_pos.partition_id not in res: + res.append(val_pos.partition_id) + return res + + def __str__(self) -> str: + res = '' + res += f' input:\n' + res += f' length:{len(self._input_vals)}\n' + for i, input_val in enumerate(self._input_vals): + res += f' offset={i}:{input_val}\n' + + res += f' output:\n' + res += f' length:{len(self._output_vals)}\n' + for i, output_val in enumerate(self._output_vals): + res += f' offset={i}:{output_val}\n' + + return res + + def __repr__(self) -> str: + return self.__str__() + +# This class is a middleware between partition splitter +# and Pipeline Scheduler. It records the graph info about +# partition input/output and provides it to scheduler. +# There are three kinds of partition in Pipeline Middleware Design +# which represents the whole process of a model execution: input-fwd-output +# 1. input_partition: records the input of a model. +# 2. mid_partition: record the splitted forwards execution of a model. +# 3. output_partition: records the output of a model. +# attributes: +# _partitions: include all partitions +# _input_partition_id: the key represents input_partition +# _output_partition_id: the key represents output_partition +class Topo(object): + def __init__(self, input_partition_id=None, output_partition_id=None) -> None: + self._partitions: Dict[int, Partition] = {} + self._input_partition_id = input_partition_id + self._output_partition_id = output_partition_id + + def set_input_partition_id(self, partition_id: int): + self._input_partition_id = partition_id + + def set_output_partition_id(self, partition_id: int): + self._output_partition_id = partition_id + + def get_input_partition_id(self): + return self._input_partition_id + + def get_output_partition_id(self): + return self._output_partition_id + + def set_partitions(self, partition_id: int, partition: Partition): + self._partitions[partition_id] = partition + + def get_mid_partitions(self): + res = {} #{partition_id: Partition} + for partition_id, partition in self._partitions.items(): + if self._input_partition_id == partition_id or self._output_partition_id == partition_id: + continue + res[partition_id] = partition + return res + + def get_mid_partition_ids(self): + return list(self.get_mid_partitions().keys()) + + def get_input_partition(self): + if self._input_partition_id is not None: + return self._partitions[self._input_partition_id] + return None + + def get_output_partition(self): + if self._output_partition_id is not None: + return self._partitions[self._output_partition_id] + return None + + def get_partition_by_id(self, partition_id): + return self._partitions[partition_id] + + def __str__(self) -> str: + res = '' + if len(self._partitions) == 0: + return 'Empty Topo Graph.' + + input_part = self.get_input_partition() + if input_part is not None: + res += '{\n' + res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}' + res += '}\n' + + mid_parts = self.get_mid_partitions() + for i, (partition_id, part) in enumerate(mid_parts.items()): + res += '{\n' + res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}' + res += '}\n' + + output_part = self.get_output_partition() + if output_part is not None: + res += '{\n' + res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}' + res += '}\n' + + return res + + def __repr__(self) -> str: + return self.__str__() + \ No newline at end of file diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py index 4d37c9833..9731530a6 100644 --- a/colossalai/pipeline/pipelinable.py +++ b/colossalai/pipeline/pipelinable.py @@ -9,7 +9,7 @@ from colossalai.nn.layer.utils import CheckpointModule from colossalai.tensor import ColoParameter from colossalai.core import global_context as gpc from colossalai.context import ParallelMode -from .layer_sepc import LayerSpec +from .layer_spec import LayerSpec class PipelinableContext(InsertPostInitMethodToModuleSubClasses): diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index 830e2bf2d..4739cdaa9 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -8,18 +8,28 @@ from typing import Any, Callable, Dict, List, Tuple import torch import torch.distributed.rpc as rpc -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map, - split_batch, tensor_shape_list, type_detail) from torch import autograd, nn, optim from torch._C._distributed_rpc import PyRRef from torch.futures import Future +from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc.utils import ( + get_batch_lengths, + pyobj_map, + pytree_filter, + pytree_map, + split_batch, + tensor_shape_list, + type_detail, +) + class Phase(Enum): FORWARD = 0 BACKWARD = 1 UPDATE = 2 + INPUT = 3 class UniqueKey: @@ -134,6 +144,7 @@ class WorkerBase(ABC): self.partition_args = partition_args self.criterion = criterion self.metric = metric + self.reset = False # context to maintain loop self._initialize_context_container() @@ -164,6 +175,7 @@ class WorkerBase(ABC): self.work_list_condition_lock = threading.Condition(threading.Lock()) self.output_list_condition_lock = threading.Condition(threading.Lock()) self.label_lock = threading.Condition(threading.Lock()) + self.reset_condition = threading.Condition(threading.Lock()) def _initialize_partition(self): partition_fn = self.partition_fn @@ -173,6 +185,41 @@ class WorkerBase(ABC): self.module_partition: nn.Module = partition_fn(*partition_args).to(device) self.partition_condition_lock.notify_all() + def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None): + with self.output_list_condition_lock: + self.output_list_condition_lock.wait_for(lambda: key in self.output_list) + output_work_item = self.output_list[key] + output = output_work_item.output + if not ref_use and output_work_item.phase != Phase.INPUT: + self.output_list.pop(key) + + if not ref_use and output_work_item.phase != Phase.INPUT: + output_work_item.refcount += 1 + refcount = output_work_item.refcount + # lifecycle management for DAG scheduler + if output_work_item.phase == Phase.FORWARD: + lifecycle = len(self.get_consumer_stage_ids()) + if self.is_model_output(): # an extra reference for scheduler collecting results + lifecycle += 1 + elif output_work_item.phase == Phase.BACKWARD: + lifecycle = len(self.get_producer_stage_ids()) + if self.is_model_input() and self._is_last_step( + output_work_item): # an extra reference for ensure_backward + lifecycle += 1 + else: + lifecycle = 0 + refcount = 0 + + with self.output_list_condition_lock: + if refcount < lifecycle: + self.output_list[key] = output_work_item + self.output_list_condition_lock.notify_all() + + if isinstance(output, Future): + output = output.wait() + + return output + def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None: assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs" assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None" @@ -182,23 +229,21 @@ class WorkerBase(ABC): # construction of partition is executed after the registion of pp_rank_to_worker_rref self._initialize_partition() - def get_output_by_key(self, key: UniqueKey) -> Any: - with self.output_list_condition_lock: - self.output_list_condition_lock.wait_for(lambda: key in self.output_list) - output_work_item = self.output_list[key] - - output = output_work_item.output - if isinstance(output, Future): - output = output.wait() - - output_work_item.refcount += 1 - - # all consumers have been satisfied, the work_item can be released - with self.output_list_condition_lock: - if output_work_item.refcount >= len(self.consumer_stage_ids): - self.output_list.pop(key) + # res_use works for lifecycle counter, + # if ref_use is True, lifecycle won't add. + # offset supports get partial output to reduce comm costs. + def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any: + output = self._get_output_all(key, ref_use, rank) + if offsets is None: # get all for non iterable output + return output + else: # get part for iterable output + output = [output[i] for i in offsets] return output + def get_numels(self) -> int: + numel = sum(param.numel() for param in self.module_partition.parameters()) + return numel + def get_parameters(self) -> List[torch.Tensor]: return [p for p in self.module_partition.parameters()] @@ -215,8 +260,10 @@ class WorkerBase(ABC): self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) return self.module_partition.state_dict() - def _make_args_kwargs(self, microbatch): + def _make_args_kwargs(self, microbatch, merge=False): if isinstance(microbatch, dict): + if merge: + return list(microbatch.values()), {} return [], microbatch elif isinstance(microbatch, torch.Tensor): return [microbatch], {} @@ -228,24 +275,58 @@ class WorkerBase(ABC): kwargs.update(arg) else: args.append(arg) + if merge: + arg_lst = args + for arg in kwargs.values(): + arg_lst.append(arg) + return arg_lst, {} return args, kwargs else: raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}") # just for first pp_rank def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool): - assert self.consumer_stage_ids is not None key = UniqueKey(microbatch_id, Phase.FORWARD) output = self._get_future_by_device() - # make args and kwargs - args, kwargs = self._make_args_kwargs(microbatch) + if not self.use_middleware(): + # make args and kwargs + args, kwargs = self._make_args_kwargs(microbatch) - work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None, - self.num_microbatches, forward_only) - with self.work_list_condition_lock: - self.work_list[key] = work_item - self.work_list_condition_lock.notify_all() + work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None, + self.num_microbatches, forward_only) + with self.work_list_condition_lock: + self.work_list[key] = work_item + self.work_list_condition_lock.notify_all() + else: + # make args and kwargs + arg_lst, _ = self._make_args_kwargs(microbatch, merge=True) + + # first stage assign correct input into other stages + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + input_partition = topo.get_input_partition() + self_input_offsets = input_partition.get_output_offsets(self_partition_id) + recv_input_key = UniqueKey(microbatch_id, Phase.INPUT) + + # set input for self rank + self_arg_lst = [] + for off in self_input_offsets: + self_arg_lst.append(arg_lst[off]) + + work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None, + self.num_microbatches, forward_only) + with self.work_list_condition_lock: + self.work_list[key] = work_item + self.work_list_condition_lock.notify_all() + + # put input tensor which other nodes need into output_list as Phase.INPUT + work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, + self.num_microbatches, forward_only) + + with self.output_list_condition_lock: + self.output_list[recv_input_key] = work_item_remote + self.output_list_condition_lock.notify_all() # just for last pp_rank def set_labels(self, microbatch_id: int, microlabels: Any): @@ -268,62 +349,159 @@ class WorkerBase(ABC): self.work_list[key] = work_item self.work_list_condition_lock.notify_all() - def subscribe_producer(self, microbatch_id: int, forward_only: bool): + def _subscribe_producer(self, microbatch_id: int, forward_only: bool): """ You should call this function asynchronously """ - assert self.producer_stage_ids is not None - producer_num = len(self.producer_stage_ids) - assert producer_num > 0, "only stage that has producers can subscribe producers" - stage_id = self.pp_rank - subscribe_forward_futures: List[Future] = [None] * producer_num output = self._get_future_by_device() + if not self.use_middleware(): + producer_num = len(self.producer_stage_ids) + subscribe_forward_futures: List[Future] = [None] * producer_num + for i in range(producer_num): + producer_stage_id = self.producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) + else: + producer_stage_ids = self.get_producer_stage_ids() + producer_num = len(producer_stage_ids) + if self.need_model_input(): + producer_num += 1 # for input partition + subscribe_forward_futures: List[Future] = [None] * producer_num - for i in range(producer_num): - producer_stage_id = self.producer_stage_ids[i] - producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) - producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] - subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) + # TODO(jiangziyue) get single value instead of the whole output + if self.need_model_input(): + producer_stage_id = 0 + producer_output_key = UniqueKey(microbatch_id, Phase.INPUT) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + offsets = self._get_input_offsets_by_index(target_index=0) + subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, + rank=self.pp_rank, + offsets=offsets) + + for i in range(0, producer_num - 1): + producer_stage_id = producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + target_index = i + 1 + offsets = self._get_input_offsets_by_index(target_index=target_index) + if offsets is not None and len(offsets) == 0: # no need to do rpc + subscribe_forward_futures[target_index] = [] + else: + subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( + producer_output_key, rank=self.pp_rank) + + else: + for i in range(producer_num): + producer_stage_id = producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] + target_index = i + offsets = self._get_input_offsets_by_index(target_index=target_index) + if offsets is not None and len(offsets) == 0: # no need to do rpc + subscribe_forward_futures[target_index] = [] + else: + subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( + producer_output_key, rank=self.pp_rank, offsets=offsets) work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, microbatch_id, None, self.num_microbatches, forward_only) - # add work_item to work_list - with self.work_list_condition_lock: - key = UniqueKey(microbatch_id, Phase.FORWARD) - assert key not in self.work_list - self.work_list[key] = work_item_from_producer - self.work_list_condition_lock.notify_all() + return work_item_from_producer - def subscribe_consumer(self, microbatch_id: int): + # TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one. + def subscribe_producer(self, microbatch_id: int, forward_only: bool): + key = UniqueKey(microbatch_id, Phase.FORWARD) + with self.work_list_condition_lock: + if key not in self.work_list: + # On current PP middleware design for DAG, get_output_by_key used by _subscribe_producer + # can only be executed once for every producer-consumer stage pair, which is necessary + # to count the lifecycle of work_item. So, keeping the _subscribe_producer in the same + # lock of work_item queue operation gurantees the consistency of lifecycle counter. + work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only) + self.work_list[key] = work_item_from_producer + self.work_list_condition_lock.notify_all() + + def _subscribe_consumer(self, microbatch_id: int): """ You should call this function asynchronously """ - assert self.producer_stage_ids is not None - consumer_num = len(self.consumer_stage_ids) - assert consumer_num > 0, "only stage that has consumers can subscribe comsumers" - stage_id = self.pp_rank - subscribe_backward_futures: List[Future] = [None] * consumer_num output = self._get_future_by_device() - + if not self.use_middleware(): + consumer_stage_ids = self.consumer_stage_ids + else: + consumer_stage_ids = self.get_consumer_stage_ids() + consumer_num = len(consumer_stage_ids) + subscribe_backward_futures: List[Future] = [None] * consumer_num for i in range(consumer_num): - consumer_stage_id = self.consumer_stage_ids[i] + consumer_stage_id = consumer_stage_ids[i] consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD) consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id] - subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key) + target_index = i + offsets = self._get_output_offsets_by_index(target_index=target_index) + if offsets is not None and len(offsets) == 0: # no need to do rpc + subscribe_backward_futures[target_index] = [] + else: + subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key( + consumer_output_key, rank=self.pp_rank, offsets=offsets) # flatten args work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output, microbatch_id, None, self.num_microbatches, False) - # add work_item to work_list + return work_item_from_consumer + + def subscribe_consumer(self, microbatch_id: int): + key = UniqueKey(microbatch_id, Phase.BACKWARD) with self.work_list_condition_lock: - key = UniqueKey(microbatch_id, Phase.BACKWARD) - assert key not in self.work_list - self.work_list[key] = work_item_from_consumer - self.work_list_condition_lock.notify_all() + if key not in self.work_list: + # On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer + # can only be executed once for every producer-consumer stage pair, which is necessary + # to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same + # lock of work_item queue operation gurantees the consistency of lifecycle counter. + work_item_from_consumer = self._subscribe_consumer(microbatch_id) + self.work_list[key] = work_item_from_consumer + self.work_list_condition_lock.notify_all() + + def get_producer_stage_ids(self): + producer_stage_ids = [] + rank = self.pp_rank + if not self.use_middleware(): + prev_rank = rank - 1 + if prev_rank >= 0: + producer_stage_ids.append(prev_rank) + else: + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + input_partition_ids = self_partition.get_input_partition_ids() + model_input_partition_id = topo.get_input_partition_id() + for partition_id in input_partition_ids: + # ignore input partition in current implementation. + # it will be specially tackled. + if partition_id != model_input_partition_id: + producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo)) + return producer_stage_ids + + def get_consumer_stage_ids(self): + consumer_stage_ids = [] + rank = self.pp_rank + if not self.use_middleware(): + next_rank = rank + 1 + if next_rank <= self.actual_stage_num - 1: + consumer_stage_ids.append(next_rank) + else: + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + output_partition_ids = self_partition.get_output_partition_ids() + model_output_partition_id = topo.get_output_partition_id() + for partition_id in output_partition_ids: + if model_output_partition_id != partition_id: + consumer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo)) + return consumer_stage_ids def _get_producer_consumer(self) -> None: rank = self.pp_rank @@ -331,16 +509,212 @@ class WorkerBase(ABC): assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" # should be aranged in order, the order of the input of current forward - self.producer_stage_ids = [] - self.consumer_stage_ids = [] + self.producer_stage_ids = self.get_producer_stage_ids() + self.consumer_stage_ids = self.get_consumer_stage_ids() - # Just for demo - prev_rank = rank - 1 - next_rank = rank + 1 - if prev_rank >= 0: - self.producer_stage_ids.append(prev_rank) - if next_rank <= self.actual_stage_num - 1: - self.consumer_stage_ids.append(next_rank) + def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo): + partition_ids = topo.get_mid_partition_ids() + return partition_ids[pp_rank] + + def partition_id_to_pp_rank(self, partition_id: int, topo: Topo): + partition_ids = topo.get_mid_partition_ids() + for i, id in enumerate(partition_ids): + if id == partition_id: + return i + + def get_topo(self): + with self.partition_condition_lock: + self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + if hasattr(self.module_partition, '_topo'): + return self.module_partition._topo + else: + return None + + def use_middleware(self): + topo = self.get_topo() + return topo is not None + + def _get_input_offsets_by_index(self, target_index): + res = [] + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + model_input_partition_id = topo.get_input_partition_id() + input_vals = self_partition.get_input_vals() + producer_stage_ids = self.get_producer_stage_ids() + if self.need_model_input(): + # 0 for data from input batch + # >= 1 for data from prev stages + base = 1 + else: + # data from prev stages + base = 0 + for val in input_vals: + val_pos = val.get() + src_partition_id = val_pos.partition_id + src_offset = val_pos.offset + src_index = base + src_partition = topo.get_partition_by_id(src_partition_id) + output_len = len(src_partition.get_output_vals()) + # data from not-input partition + if src_partition_id != model_input_partition_id: + src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo) + src_index = base + for i, stage_id in enumerate(producer_stage_ids): + if stage_id == src_stage_id: + src_index += i + break + else: # data from input partition + src_index = 0 + # when output_len = 1, not iterable + if target_index == src_index: + if output_len == 1: + res = None # offset = None to get all outputs + return res + else: + res.append(src_offset) + return res + + def _get_output_offsets_by_index(self, target_index): + res = [] + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + output_vals = self_partition.get_output_vals() + consumer_stage_ids = self.get_consumer_stage_ids() + for val_list in output_vals: + # An output may be passed to many down stages. + target = None + for val_pos in val_list.get(): + dst_partition_id = val_pos.partition_id + dst_offset = val_pos.offset + dst_partition = topo.get_partition_by_id(dst_partition_id) + input_len = len(dst_partition.get_input_vals()) + dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo) + for i, stage_id in enumerate(consumer_stage_ids): + if stage_id == dst_stage_id: + dst_index = i + break + if target_index == dst_index: + if input_len == 1: + res = None # offset = None to get all outputs + return res + else: + res.append(dst_offset) + return res + + # TODO(jiangziyue) get single value instead of the whole output + def _get_real_args_kwargs_fwd(self, args_or_kwargs): + if not self.use_middleware(): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + else: + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + if self.is_first_stage(): + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + else: # get by offset + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + model_input_partition_id = topo.get_input_partition_id() + input_vals = self_partition.get_input_vals() + producer_stage_ids = self.get_producer_stage_ids() + if self.need_model_input(): + # 0 for data from input batch + # >= 1 for data from prev stages + base = 1 + else: + # data from prev stages + base = 0 + for val in input_vals: + val_pos = val.get() + src_partition_id = val_pos.partition_id + src_offset = val_pos.offset + src_index = base + src_partition = topo.get_partition_by_id(src_partition_id) + output_len = len(src_partition.get_output_vals()) + # data from not-input partition + if src_partition_id != model_input_partition_id: + src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo) + src_index = base + for i, stage_id in enumerate(producer_stage_ids): + if stage_id == src_stage_id: + src_index += i + break + else: # data from input partition + src_index = 0 + # when output_len = 1, not iterable + if output_len == 1: + target = args_or_kwargs[src_index] + else: + offsets = self._get_input_offsets_by_index(src_index) + real_offset = offsets.index(src_offset) + target = args_or_kwargs[src_index][real_offset] + flatten_args.append(target) + args_or_kwargs = flatten_args + return args_or_kwargs + + # TODO(jiangziyue) get single value instead of the whole output + def _get_real_args_kwargs_bwd(self, args_or_kwargs): + if not self.use_middleware(): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + else: + for i, arg in enumerate(args_or_kwargs): + args_or_kwargs[i] = arg.wait() + if args_or_kwargs is not None: # get by offset + flatten_args = [] + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition: Partition = topo.get_partition_by_id(self_partition_id) + output_vals = self_partition.get_output_vals() + consumer_stage_ids = self.get_consumer_stage_ids() + for val_list in output_vals: + # An output may be passed to many down stages. + target = None + for val_pos in val_list.get(): + dst_partition_id = val_pos.partition_id + dst_offset = val_pos.offset + dst_partition = topo.get_partition_by_id(dst_partition_id) + input_len = len(dst_partition.get_input_vals()) + dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo) + for i, stage_id in enumerate(consumer_stage_ids): + if stage_id == dst_stage_id: + dst_index = i + break + if input_len == 1: + part_grad = args_or_kwargs[dst_index] + else: + offsets = self._get_output_offsets_by_index(dst_index) + real_offsets = offsets.index(dst_offset) + part_grad = args_or_kwargs[dst_index][real_offsets] + + if target is None: + target = part_grad + elif part_grad is not None: + target += part_grad + else: + continue + flatten_args.append(target) + args_or_kwargs = flatten_args + return args_or_kwargs @abstractmethod def _get_work_item_key(self) -> UniqueKey: @@ -354,6 +728,23 @@ class WorkerBase(ABC): def is_last_stage(self): return self.pp_rank == self.actual_stage_num - 1 + def need_model_input(self): + need_input = False + topo: Topo = self.get_topo() + self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) + self_partition = topo.get_partition_by_id(self_partition_id) + partition_inputs = self_partition.get_input_partition_ids() + model_input_partition_id = topo.get_input_partition_id() + if model_input_partition_id in partition_inputs: + need_input = True + return not self.is_first_stage() and need_input + + def is_model_output(self): + return self.is_last_stage() + + def is_model_input(self): + return self.is_first_stage() + def _default_data_process_func(self, args_kwargs): if self.is_first_stage(): args = args_kwargs[0] @@ -390,11 +781,16 @@ class WorkerBase(ABC): # parse and integrate args and kwargs if is_first_stage: - args = get_real_args_kwargs(args) - kwargs = get_real_args_kwargs(kwargs) + args = self._get_real_args_kwargs_fwd(args) + kwargs = self._get_real_args_kwargs_fwd(kwargs) args_kwargs = (args, kwargs) else: - args_kwargs = get_real_args_kwargs(args) + args_kwargs = self._get_real_args_kwargs_fwd(args) + + args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device, + process_types=torch.device) # change devices from last stage to current device args, kwargs = data_process_func(args_kwargs) @@ -459,6 +855,9 @@ class WorkerBase(ABC): stage_input_kwargs, stage_outputs, checkpoint=use_checkpoint) + consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in + # if not forward_only, do the backward if not forward_only: if is_last_stage: # if it is the last stage, trigger backward automatic @@ -486,21 +885,43 @@ class WorkerBase(ABC): # overlap recompute and future.wait if not is_last_stage: - grad_tensors = get_real_args_kwargs(args) + grad_tensors = self._get_real_args_kwargs_bwd(args) else: grad_tensors = None # take tensor only (for only tensor can do backward) - stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor) - grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor) + # TODO(jiangziyue) : All values which should do bp are torch.Tensor? + stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor) + grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor) + # output all input's grad to producer, even it has no grad(output None) + # to make the offset aligned to the topo's record. + if grad_tensors is not None: + filtered_outputs = [] + filtered_grads = [] + for i, grad in enumerate(grad_tensors): + stage_output = stage_outputs[i] + if stage_output.requires_grad and grad is not None: + filtered_outputs.append(stage_output) + filtered_grads.append(grad) + + stage_outputs = filtered_outputs + grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU autograd.backward(stage_outputs, grad_tensors=grad_tensors) # collect grad of input tensor consume_result = [] if not is_first_stage: - pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor) - pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor) + # In current design, input mush be a flatten args. + for arg in stage_input_args: + if isinstance(arg, torch.Tensor): + consume_result.append(arg.grad) + else: + consume_result.append(None) + consume_result = pyobj_map( + consume_result, fn=lambda x: x.to('cpu'), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU else: raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") @@ -532,11 +953,11 @@ class WorkerBase(ABC): def _hook_before_step(self): pass - def _reset_context(self): - self.forward_times = 0 - self.backward_times = 0 - self.outstanding = 0 - self._initialize_outstanding_range() + # install the main loop to wait for next batch input + def _wait_for_reset(self): + with self.reset_condition: + self.reset_condition.wait_for(lambda: self.reset) + self.reset = False # do the main loop to consume ready_list def _work_loop(self): @@ -547,10 +968,10 @@ class WorkerBase(ABC): # main loop while True: work_item_key = self._get_work_item_key() - # move current work item to output_list to activate subscribe in advance with self.work_list_condition_lock: - work_item = self.work_list.pop(work_item_key) + self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list) + work_item = self.work_list[work_item_key] with self.output_list_condition_lock: # assert work_item_key not in self.output_list @@ -559,27 +980,37 @@ class WorkerBase(ABC): consume_result = self._consume_work_item_by_phase(work_item) + with self.work_list_condition_lock: + self.work_list.pop(work_item_key) work_item.output.set_result(consume_result) # if is last step in one batch reset context and do step if self._is_last_step(work_item): - self._hook_before_step() - if hasattr(self, 'optimizer') and not work_item.forward_only: - self.step() - self._reset_context() + self._wait_for_reset() + + # reset context and resume loop + def reset_context(self): + self.forward_times = 0 + self.backward_times = 0 + self.outstanding = 0 + self._initialize_outstanding_range() + with self.work_list_condition_lock: + self.work_list.clear() + + with self.output_list_condition_lock: + self.output_list.clear() + + with self.reset_condition: + self.reset = True + self.reset_condition.notify_all() def initialize_optimizer(self, optimizer_class: type, **kwargs): self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs) - self.step_lock = threading.Lock() - self.step_lock.acquire() - - def wait_for_step(self): - self.step_lock.acquire() def step(self): + self._hook_before_step() self.optimizer.step() self.optimizer.zero_grad() - self.step_lock.release() class PipelineEngineBase(ABC, nn.Module): @@ -611,8 +1042,6 @@ class PipelineEngineBase(ABC, nn.Module): self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict() - self.step_futs: List[Future] = [] - self._check_argument() self._create_pp_rank_to_rpc_worker_id() self._create_pp_rank_to_module_partition_id() @@ -639,7 +1068,7 @@ class PipelineEngineBase(ABC, nn.Module): def _create_pp_rank_to_rpc_worker_id(self) -> None: """create a map from model partition to stage_id, which is useful when use_interleave is True. - e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then + e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part of partitions will be moved to device 0 and the others to device 1 """ @@ -692,6 +1121,15 @@ class PipelineEngineBase(ABC, nn.Module): for fut in sync_futs: fut.wait() + def remote_numels(self) -> Dict[int, int]: + numels = {} + actual_stage_num = self._get_actual_stage_num() + for stage_id in range(actual_stage_num): + worker_rref = self.pp_rank_to_worker_rref[stage_id] + numel = worker_rref.rpc_sync().get_numels() + numels[stage_id] = numel + return numels + def remote_parameters(self) -> Dict[int, List[torch.Tensor]]: parameters = {} actual_stage_num = self._get_actual_stage_num() @@ -728,9 +1166,14 @@ class PipelineEngineBase(ABC, nn.Module): ret_future[pp_rank][microbatch_id - actual_stage_num].wait() else: key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD) + futs = [] for pp_rank in input_pp_ranks: worker_rref = self.pp_rank_to_worker_rref[pp_rank] - worker_rref.rpc_sync().get_output_by_key(key) + fut = worker_rref.rpc_async().get_output_by_key(key, ref_use=True, offsets=[]) + futs.append(fut) + + for fut in futs: + fut.wait() def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]: num_microbatches = self.num_microbatches @@ -748,6 +1191,7 @@ class PipelineEngineBase(ABC, nn.Module): # TODO : add relationship between output_pp_ranks and parts of microlabels worker_rref.remote().set_labels(microbatch_id, microlabels) + # TODO(jiangziyue) : get model output with single value, instead of merging into last stage. def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): key = UniqueKey(microbatch_id, Phase.FORWARD) for pp_rank in output_pp_ranks: @@ -756,10 +1200,16 @@ class PipelineEngineBase(ABC, nn.Module): def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): if not forward_only: + backward_result = [] for pp_rank in input_pp_ranks: worker_rref = self.pp_rank_to_worker_rref[pp_rank] key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) - worker_rref.rpc_sync().get_output_by_key(key) + fut = worker_rref.rpc_async().get_output_by_key( + key, offsets=[]) # only ensure the res exists, no need for real data. + backward_result.append(fut) + + for fut in backward_result: + fut.wait() def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): forward_result = [] @@ -776,6 +1226,17 @@ class PipelineEngineBase(ABC, nn.Module): return forward_result + def _reset_worker(self): + actual_stage_num = self._get_actual_stage_num() + reset_futs: List[Future] = [] + for pp_rank in range(actual_stage_num): + worker_rref = self.pp_rank_to_worker_rref[pp_rank] + fut = worker_rref.rpc_async().reset_context() + reset_futs.append(fut) + + for fut in reset_futs: + fut.wait() + def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False): batch_lengths = get_batch_lengths(batch) batch_length = batch_lengths[0] @@ -800,7 +1261,7 @@ class PipelineEngineBase(ABC, nn.Module): for microbatch_id in range(num_microbatches): # control data input speed # to prevent exceed of wait limitations - self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future) + # self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future) batch_start = microbatch_size * microbatch_id batch_end = min(batch_start + microbatch_size, batch_length) @@ -824,11 +1285,9 @@ class PipelineEngineBase(ABC, nn.Module): forward_result = self._collect_forward_result(output_pp_ranks, ret_future) if not forward_only and hasattr(self, 'optimizer_class'): - # wait for all step - for pp_rank in self.pp_rank_to_worker_rref: - worker_rref = self.pp_rank_to_worker_rref[pp_rank] - worker_rref.rpc_sync().wait_for_step() + self.step() + self._reset_worker() # reset worker attributes for next batch return forward_result def initialize_optimizer(self, optimizer_class: type, **kwargs): @@ -839,10 +1298,11 @@ class PipelineEngineBase(ABC, nn.Module): def step(self): actual_stage_num = self._get_actual_stage_num() + step_futs: List[Future] = [] for pp_rank in range(actual_stage_num): worker_rref = self.pp_rank_to_worker_rref[pp_rank] fut = worker_rref.rpc_async().step() - self.step_futs.append(fut) + step_futs.append(fut) - for fut in self.step_futs: + for fut in step_futs: fut.wait() diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py index 0ab3a3694..e6aa961f1 100644 --- a/colossalai/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/pipeline/rpc/_pipeline_schedule.py @@ -3,11 +3,12 @@ from typing import Callable, Dict, List import torch import torch.distributed as dist -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem) from torch._C._distributed_rpc import PyRRef from torch.futures import Future +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem + # Implementation of different Pipeline schedule # Worker defines the worker for each stage # PipelineEngine is the class for use @@ -86,12 +87,9 @@ class OneFOneBWorker(WorkerBase): outstanding_min = actual_stage_num - pp_rank - 1 outstanding_max = actual_stage_num - pp_rank self.outstanding_range = (outstanding_min, outstanding_max) - elif target_key.microbatch_id == num_microbatches - 1: + if target_key.microbatch_id == num_microbatches - 1: self.outstanding_range = (0, 0) - with self.work_list_condition_lock: - self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) - return target_key diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/pipeline/rpc/utils.py index 361f6faf7..06e6d976d 100644 --- a/colossalai/pipeline/rpc/utils.py +++ b/colossalai/pipeline/rpc/utils.py @@ -6,11 +6,25 @@ from typing import Any, Callable, Dict, List, Tuple, Type, Union import torch import torch.distributed.rpc as rpc import torch.multiprocessing as mp -from colossalai.initialize import launch -from colossalai.pipeline.pipeline_process_group import ppg from torch._C._distributed_rpc import _is_current_rpc_agent_set from torch.futures import Future +from colossalai.initialize import launch +from colossalai.pipeline.pipeline_process_group import ppg + + +def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any: + if isinstance(obj, process_types): + return fn(obj) + elif type(obj) is dict: + return {k: pyobj_map(obj[k], fn, process_types) for k in obj} + elif type(obj) is tuple: + return tuple(pyobj_map(o, fn, process_types) for o in obj) + elif type(obj) is list: + return list(pyobj_map(o, fn, process_types) for o in obj) + else: + return obj + def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: """process object recursively, like pytree @@ -19,10 +33,10 @@ def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = obj (:class:`Any`): object to process fn (:class:`Callable`): a function to process subobject in obj process_types (:class: `type | tuple[type]`): types to determine the type to process - map_all (:class: `bool`): if map_all is True, then any type of element will use fn + map_all (:class: `bool`): if map_all is True, then any type of element will use fn Returns: - :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` + :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` """ if isinstance(obj, dict): return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj} @@ -137,5 +151,5 @@ def parse_args(): parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') parser.add_argument('--master_addr', type=str, default='localhost') parser.add_argument('--master_port', type=str, default='29020') - parser.add_argument('--num_worker_threads', type=str, default=128) + parser.add_argument('--num_worker_threads', type=int, default=128) return parser.parse_args() diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py index 5afed0225..df7226644 100644 --- a/colossalai/pipeline/utils.py +++ b/colossalai/pipeline/utils.py @@ -6,6 +6,7 @@ from colossalai.logging import get_dist_logger from colossalai.nn.layer.utils import CheckpointModule from typing import List +from collections import OrderedDict def _binary_partition(weights: List, start: int, end: int): """Returns the binary partition position of `weights`, given the start @@ -159,8 +160,10 @@ def build_kwargs_for_module(function, input_tensor, kw_dict): kwargs_offset = 0 elif isinstance(input_tensor, torch.Tensor): kwargs_offset = 1 - else: - assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' + elif isinstance(input_tensor, (tuple, OrderedDict)): + #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' + # Huggingface will take their own structures based on OrderedDict as the output + # between layers so we've to close this check. kwargs_offset = len(input_tensor) args_name_list = list(sig.parameters.keys()) kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]} diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index 4946d7077..b2da64e6c 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -1,19 +1,18 @@ +from . import distspec +from .colo_parameter import ColoParameter +from .colo_tensor import ColoTensor +from .comm_spec import CollectiveCommPattern, CommSpec +from .compute_spec import ComputePattern, ComputeSpec +from .dist_spec_mgr import DistSpecManager +from .distspec import ReplicaSpec, ShardSpec +from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager from .process_group import ProcessGroup from .tensor_spec import ColoTensorSpec -from .distspec import ShardSpec -from .distspec import ReplicaSpec - -from .compute_spec import ComputeSpec, ComputePattern -from .colo_tensor import ColoTensor -from .colo_parameter import ColoParameter -from .utils import convert_parameter, named_params_with_colotensor -from .dist_spec_mgr import DistSpecManager -from .param_op_hook import ParamOpHook, ParamOpHookManager -from .comm_spec import CollectiveCommPattern, CommSpec -from . import distspec +from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor __all__ = [ 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter', - 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec', - 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern' + 'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', + 'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', + 'merge_same_dim_mesh_list' ] diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index 17c326516..92220d9e2 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -1,15 +1,32 @@ -import torch - from typing import Optional +import torch + from colossalai.tensor.colo_tensor import ColoTensor from colossalai.tensor.const import TensorType -from colossalai.tensor import ColoTensorSpec -from colossalai.tensor.param_op_hook import ParamOpHookManager +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.tensor.tensor_spec import ColoTensorSpec -def filter_args(func, *args): - return [arg for arg in args if func(arg)] +def filter_colo_parameters(*args, **kwargs): + param_list = [] + + def get_colo_parameters(element) -> None: + if isinstance(element, list) or isinstance(element, tuple): + for e in element: + get_colo_parameters(e) + elif isinstance(element, dict): + raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.") + elif isinstance(element, ColoParameter): + param_list.append(element) + return + + for a in args: + get_colo_parameters(a) + for v in kwargs.values(): + get_colo_parameters(v) + + return param_list def replace_args(args, kwargs, new_args): @@ -58,18 +75,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter): @classmethod def __torch_function__(cls, func, types, args=..., kwargs=None): - if ParamOpHookManager.has_hook(): + if ColoParamOpHookManager.has_hook(): if not func.__name__.startswith('__'): if kwargs is None: kwargs = {} - params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values()) + params = filter_colo_parameters(*args, **kwargs) if len(params) > 0: with torch._C.DisableTorchFunction(): - new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values()) + new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values()) args, kwargs = replace_args(args, kwargs, new_args) ret = super().__torch_function__(func, types, args, kwargs) with torch._C.DisableTorchFunction(): - ret = ParamOpHookManager.post_op(params, ret) + ret = ColoParamOpHookManager.post_op(params, ret) return ret return super().__torch_function__(func, types, args, kwargs) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index ce6d20c0e..3712d6a0a 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -1,14 +1,16 @@ -from .op_wrapper import _COLOSSAL_OPS -from .const import TensorType from copy import copy -import torch from functools import lru_cache +from typing import Callable, Optional, Set + +import torch -from colossalai.tensor import ColoTensorSpec -from colossalai.tensor import ProcessGroup, ReplicaSpec from colossalai.tensor.dist_spec_mgr import DistSpecManager -from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern -from typing import Optional, Set, Callable +from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec +from colossalai.tensor.process_group import ProcessGroup +from colossalai.tensor.tensor_spec import ColoTensorSpec + +from .const import TensorType +from .op_wrapper import _COLOSSAL_OPS @lru_cache(None) @@ -55,27 +57,29 @@ class ColoTensor(torch.Tensor): The Colotensor can be initialized with a PyTorch tensor in the following ways. >>> pg = ProcessGroup() - >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()) + >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())) >>> # The tensor passed in is a tensor after sharding but not a global tensor. - >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size), - >>> dims=[0], + >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size), + >>> dims=[0], >>> num_partitions=[world_size]) >>> tensor_spec = ColoTensorSpec(pg, shard_spec) >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) - + Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()). """ + torch_major = int(torch.__version__.split('.')[0]) + torch_minor = int(torch.__version__.split('.')[1]) def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor': """ The signature of the __new__ has to be consistent with the torch.Tensor. - + Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. spec (TensorSpec, optional): the tensor spec of initialization. - + Returns: ColoTensor: a ColoTensor wrappers the data. """ @@ -100,7 +104,6 @@ class ColoTensor(torch.Tensor): self.process_group = spec.pg self._type = TensorType.NONMODEL - self._graph_node = None def has_compute_spec(self) -> bool: return self.compute_spec is not None @@ -112,9 +115,9 @@ class ColoTensor(torch.Tensor): return self.process_group def set_process_group(self, pg: ProcessGroup): - """set_process_group + """set_process_group change the pg of the ColoTensor. Note that the valid use cases is limited. - Only existing pg is DP and dist spec is REPLICaTE is valid. + It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica. Args: pg (ProcessGroup): target pg @@ -124,10 +127,10 @@ class ColoTensor(torch.Tensor): # if the new pg is the same as the old pg, just returns if self.process_group == pg: return - assert self.process_group.tp_world_size() == 1, \ - "Can not set_process_group on a ColoTensor whose process_group has tp world group" + assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \ + "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1" assert self.dist_spec.placement.value == 'r', \ - "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE" + "Can not set_process_group on a ColoTensor whose dist spec is not Replica" self.process_group = pg @@ -135,7 +138,7 @@ class ColoTensor(torch.Tensor): return self.process_group.tp_world_size() def set_dist_spec(self, dist_spec: _DistSpec): - """set_dist_spec + """set_dist_spec set dist spec and change the payloads. Args: @@ -166,6 +169,16 @@ class ColoTensor(torch.Tensor): if func in _COLOSSAL_OPS: func = _COLOSSAL_OPS[func] + if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12): + # in order to trigger pre-op hook in the forward of checkpoint module + # we have to capture the `backward` function + # and make sure that it does not in `torch._C.DisableTorchFunction()` context + if func is torch.Tensor.backward: + assert len(args) == 1 # only has 1 paramter + backward_tensor = torch.Tensor(args[0]) + tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()} + return backward_tensor.backward(**tensor_kwargs) + with torch._C.DisableTorchFunction(): ret = func(*args, **kwargs) if func in _get_my_nowrap_functions(): @@ -178,7 +191,7 @@ class ColoTensor(torch.Tensor): return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}' def _redistribute(self, dist_spec: _DistSpec) -> None: - """_redistribute + """_redistribute Note the function will not handle the logic of backward propagation! It is used during model tensor initializations as an internal function. @@ -191,12 +204,12 @@ class ColoTensor(torch.Tensor): self.dist_spec = dist_spec def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor': - """redistribute + """redistribute Redistribute the tensor among processes. The rule is like this: - + 1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the DP process group not changed. - + 2. If the pg is not not None and not equal to the current process group. First, convert the tensor as replicated among the TP process group. Second, reset the process group to the new pg. @@ -220,7 +233,7 @@ class ColoTensor(torch.Tensor): return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec)) def to_replicate_(self): - """to_replicate_ + """to_replicate_ an inline member function, converting dist spec of the tensor to REPLICATE """ diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index 617057a4f..3c9e0fd56 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -23,9 +23,9 @@ def _all_gather(tensor, comm_spec): torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) ] - tensor = tensor - group = process_group - dist.all_gather(tensor_list, tensor, group=group) + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() return output @@ -37,11 +37,10 @@ def _split(tensor, comm_spec): process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] for rank_list, _ in process_groups_list: if dist.get_rank() in rank_list: - tensor = tensor dim = comm_spec.shard_dim length = tensor.shape[comm_spec.shard_dim] // len(rank_list) start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length) + output = torch.narrow(tensor, dim, start, length).contiguous() return output @@ -69,17 +68,145 @@ def _all_to_all(tensor, comm_spec): return output -def _all_reduce(tensor, comm_spec): +def _all_reduce(tensor, comm_spec, async_op=False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] for rank_list, process_group in process_groups_list: if dist.get_rank() in rank_list: - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group) + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) return tensor +def _mix_gather(tensor, comm_spec): + ''' + Implement mix gather operation on device mesh based on information provided by comm_spec. + Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is + different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather + only does all-gather in one dimension. + Assume index of f and b target pairs are 'f' and 'b' + ShardingSpec => gather_dim, logical_process_axes + S0S1 => [b, f], (1, 0) + S1S0 => [b, f], (0, 1) + S01R => [f], (1, 1) + RS01 => [b], (1, 1) + Example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + S0S1: + leading_group_dim = 1 + process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" + tensor_list = [(0,0),(0,1),(0,2),(0,3),(1,0),(1,1),(1,2),(1,3)] # [(slice_id_f, slice_id_b),...] + mesh_shape = (2,4) + cat_slice = [4,2] + tmp_tensor_list = [(...,shape[f],shape[b]*4,...),(...,shape[f],shape[b]*4,...)] + tmp_tensor_list[0] = torch.cat(((0,0),(0,1),(0,2),(0,3)), dim=b) + tmp_tensor_list[1] = torch.cat(((1,0),(1,1),(1,2),(1,3)), dim=b) + output = torch.cat((tmp_tensor_list[0],tmp_tensor_list[1]), dim=a) + S1S0: + leading_group_dim = 0 + process_group = "[0, 4, 1, 5, 2, 6, 3, 7]" + tensor_list = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1),(3,0),(3,1)] + mesh_shape = (2,4) + cat_slice = [2,4] + tmp_tensor_list = [(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...)] + tmp_tensor_list[0] = torch.cat(((0,0),(0,1)), dim=b) + tmp_tensor_list[1] = torch.cat(((1,0),(1,1)), dim=b) + tmp_tensor_list[2] = torch.cat(((2,0),(2,1)), dim=b) + tmp_tensor_list[3] = torch.cat(((3,0),(3,1)), dim=b) + S10R: + leading_group_dim = 0 + process_group = "[0, 4, 1, 5, 2, 6, 3, 7]" + tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] + S01R: + leading_group_dim = 1 + process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" + tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] + ''' + total_slices = comm_spec.device_mesh.mesh_shape[0] + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] + leading_group_dim = comm_spec.logical_process_axes[0] + assert len(comm_spec.device_mesh.process_groups_dict) == 1 + _, process_group = comm_spec.device_mesh.process_groups_dict[0][0] + process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim] + + # Global all_gather + dist.all_gather(tensor_list, tensor, group=process_group) + + # This is very ugly. I'm figuring out more elegant methods + tensor_list_sorted = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices) + ] + for i in range(total_slices): + tensor_list_sorted[i] = tensor_list[process_number_list[i]] + tensor_list = tensor_list_sorted + + if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous() + else: + mesh_shape = comm_spec.device_meshes.mesh_shape + cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] + tmp_tensor_shape = list(tensor.shape) + tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0] + tmp_tensor_shape = torch.Size(tmp_tensor_shape) + tmp_tensor_list = [ + torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1]) + ] + for i in range(cat_slice[1]): + tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]), + comm_spec.gather_dim[0]).contiguous() + output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous() + + return output + + +def _mix_split(tensor, comm_spec): + ''' + Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent) + Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split + because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension. + Assume index of f and b target pairs are 'f' and 'b' + S0S1 => [b, f], (1, 0) + S1S0 => [b, f], (0, 1) + S01R => [f], (0, 0) + RS01 => [b], (0, 0) + Example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + ''' + mesh_shape = comm_spec.device_meshes.mesh_shape + dim = comm_spec.gather_dim + total_slices = comm_spec.device_mesh.mesh_shape[0] + + # Get global rank + rank = dist.get_rank() + + leading_group_dim = comm_spec.logical_process_axes[0] + process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim] + rank = process_number_list.index(rank) + + if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: + length = tensor.shape[dim[0]] // total_slices + start = length * rank + output = torch.narrow(tensor, dim[0], start, length).contiguous() + else: + tensor_shape = [tensor.shape[dim[0]], tensor.shape[dim[1]]] + rank_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] + length = [tensor_shape[0] // rank_slice[0], tensor_shape[1] // rank_slice[1]] + start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]] + tmp_output = torch.narrow(tensor, dim[0], start[0], length[0]).contiguous() + output = torch.narrow(tmp_output, dim[1], start[1], length[1]).contiguous() + + return output + + class _ReduceGrad(torch.autograd.Function): """ A customized communication operation which forward is an identity operation, @@ -205,6 +332,22 @@ class _AllToAll(torch.autograd.Function): return _all_to_all(grad_outputs, ctx.comm_spec), None +class _MixGatherForwardMixSplitBackward(torch.autograd.Function): + + @staticmethod + def symbolic(graph, input_): + return _mix_gather(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _mix_gather(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _mix_split(grad_output, ctx.comm_spec), None + + def reduce_grad(input_, comm_spec): return _ReduceGrad.apply(input_, comm_spec) @@ -225,12 +368,17 @@ def all_to_all(input_, comm_spec): return _AllToAll.apply(input_, comm_spec) +def mixgather_forward_split_backward(input_, comm_spec): + return _MixGatherForwardMixSplitBackward.apply(input_, comm_spec) + + class CollectiveCommPattern(Enum): GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" class CommSpec: @@ -256,7 +404,8 @@ class CommSpec: gather_dim=None, shard_dim=None, logical_process_axis=None, - forward_only=False): + forward_only=False, + mix_gather=False): self.comm_pattern = comm_pattern self.sharding_spec = sharding_spec self.gather_dim = gather_dim @@ -264,8 +413,14 @@ class CommSpec: self.logical_process_axis = logical_process_axis self.forward_only = forward_only if isinstance(self.logical_process_axis, list): - self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh - self.logical_process_axis = 0 + if not mix_gather: + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.logical_process_axis = 0 + else: + self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + # Create a new member `logical_process_axes` to distinguish from original flatten + self.logical_process_axes = logical_process_axis else: self.device_mesh = self.sharding_spec.device_mesh @@ -290,6 +445,10 @@ class CommSpec: elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: + res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"logical_process_asex:{self.logical_process_axes})") return ''.join(res_list) @@ -325,6 +484,11 @@ class CommSpec: forward_communication_cost = 10 backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: + # no need for axis because all devices are used in mix_gather + forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size) + backward_communication_cost = 10 + if self.forward_only: cost_dict["forward"] = forward_communication_cost cost_dict["backward"] = 0 @@ -357,4 +521,5 @@ pattern_to_func_dict = { CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward, CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input, CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad, + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: mixgather_forward_split_backward, } diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py index f1dc241a8..d5c0ce28e 100644 --- a/colossalai/tensor/dist_spec_mgr.py +++ b/colossalai/tensor/dist_spec_mgr.py @@ -1,12 +1,14 @@ -from colossalai.tensor.distspec import _DistSpec -# from colossalai.nn.layer.utils import divide -from numpy import prod from contextlib import contextmanager + import torch import torch.distributed as dist +# from colossalai.nn.layer.utils import divide +from numpy import prod from packaging import version + from colossalai.logging import get_dist_logger -from colossalai.tensor import ProcessGroup +from colossalai.tensor.distspec import _DistSpec +from colossalai.tensor.process_group import ProcessGroup # TODO(jiaruifang) circle import, move the divide to colossalai.commons. diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index 03cb090a6..ed705da0e 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -1,16 +1,19 @@ -import torch -from contextlib import contextmanager from abc import ABC, abstractmethod -from typing import List, Tuple, Any +from contextlib import contextmanager +from typing import Any, List, Tuple + +import torch + from colossalai.tensor.colo_tensor import ColoTensor -from colossalai.tensor import ColoTensorSpec +from colossalai.tensor.tensor_spec import ColoTensorSpec -class ParamOpHook(ABC): - """Hook which is triggered by each operation when operands contain ColoParameter. +class ColoParamOpHook(ABC): + """ + Hook which is triggered by each operation when operands contain ColoParameter. To customize it, you must inherit this abstract class, and implement ``pre_forward``, - ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list - of ColoParameter. + ``post_forward``, ``pre_backward`` and ``post_backward``. + These four methods apply a list of ColoParameter as input args. """ @abstractmethod @@ -30,68 +33,79 @@ class ParamOpHook(ABC): pass -class ParamOpHookManager: - """Manage your param op hooks. It only has static methods. +class ColoParamOpHookManager: + """ + Manage your param op hooks. It only has static methods. The only static method you should call is ``use_hooks(*hooks)``. """ - hooks: Tuple[ParamOpHook, ...] = tuple() + hooks: Tuple[ColoParamOpHook, ...] = tuple() @staticmethod @contextmanager - def use_hooks(*hooks: ParamOpHook): + def use_hooks(*hooks: ColoParamOpHook): """Change the param op hooks you use. Nested calling is allowed. Example: - >>> with ParamOpHookManager.use_hooks(*hooks): + >>> with ColoParamOpHookManager.use_hooks(*hooks): >>> do_something() - >>> with ParamOpHookManager.use_hooks(): + >>> with ColoParamOpHookManager.use_hooks(): >>> // clear hooks >>> do_something() """ try: - old_param_op_hooks = ParamOpHookManager.hooks - ParamOpHookManager.hooks = hooks + old_param_op_hooks = ColoParamOpHookManager.hooks + ColoParamOpHookManager.hooks = hooks yield finally: - ParamOpHookManager.hooks = old_param_op_hooks + ColoParamOpHookManager.hooks = old_param_op_hooks @staticmethod def _trigger_pre_forward(params: List[torch.Tensor]) -> None: - for hook in ParamOpHookManager.hooks: + for hook in ColoParamOpHookManager.hooks: hook.pre_forward(params) @staticmethod def _trigger_post_forward(params: List[torch.Tensor]) -> None: - for hook in ParamOpHookManager.hooks: + for hook in ColoParamOpHookManager.hooks: hook.post_forward(params) @staticmethod def _trigger_pre_backward(params: List[torch.Tensor]) -> None: - for hook in ParamOpHookManager.hooks: + for hook in ColoParamOpHookManager.hooks: hook.pre_backward(params) @staticmethod def _trigger_post_backward(params: List[torch.Tensor]) -> None: - for hook in ParamOpHookManager.hooks: + for hook in ColoParamOpHookManager.hooks: hook.post_backward(params) @staticmethod def pre_op(params: List[torch.Tensor], *args: Any) -> list: - ParamOpHookManager._trigger_pre_forward(params) - args_info = _get_colo_tensors_info(*args) - rets = PreFwdPostBwd.apply(params, *args) - return _update_colo_tensors(args_info, *rets) + ColoParamOpHookManager._trigger_pre_forward(params) + grad_args, rear_args = _get_grad_args(*args) + colo_info = _get_colo_tensors_info(*grad_args) + rets = PreFwdPostBwd.apply(params, *grad_args) + update_args = _update_colo_tensors(colo_info, *rets) + if rear_args is None: + return update_args + else: + arg_zero = (tuple(update_args),) + return arg_zero + rear_args @staticmethod def post_op(params: List[torch.Tensor], arg: Any) -> Any: - ParamOpHookManager._trigger_post_forward(params) - arg_info = _get_colo_tensors_info(arg) + ColoParamOpHookManager._trigger_post_forward(params) + colo_info = _get_colo_tensors_info(arg) ret = PostFwdPreBwd.apply(params, arg) - return _unpack_args(_update_colo_tensors(arg_info, ret)) + res = _update_colo_tensors(colo_info, ret) + if len(res) == 1: + return res[0] + else: + return res @staticmethod def has_hook() -> bool: - return len(ParamOpHookManager.hooks) > 0 + return len(ColoParamOpHookManager.hooks) > 0 class PreFwdPostBwd(torch.autograd.Function): @@ -99,11 +113,11 @@ class PreFwdPostBwd(torch.autograd.Function): @staticmethod def forward(ctx, params, *args): ctx.params = params - return _unpack_args(args) + return args @staticmethod def backward(ctx, *grads): - ParamOpHookManager._trigger_post_backward(ctx.params) + ColoParamOpHookManager._trigger_post_backward(ctx.params) return (None,) + grads @@ -116,14 +130,51 @@ class PostFwdPreBwd(torch.autograd.Function): @staticmethod def backward(ctx, *grads): - ParamOpHookManager._trigger_pre_backward(ctx.params) + ColoParamOpHookManager._trigger_pre_backward(ctx.params) return (None,) + grads -def _unpack_args(args): - if len(args) == 1: - return args[0] - return args +def _is_grad_tensor(obj) -> bool: + if torch.is_tensor(obj): + if obj.grad_fn is not None or obj.requires_grad: + return True + return False + + +def _has_grad_tensor(obj) -> bool: + if isinstance(obj, tuple) or isinstance(obj, list): + for x in obj: + if _has_grad_tensor(x): + return True + return False + elif isinstance(obj, dict): + for x in obj.values(): + if _has_grad_tensor(x): + return True + return False + else: + return _is_grad_tensor(obj) + + +def _get_grad_args(*args): + # if there is no grad tensors, do nothing + if not _has_grad_tensor(args): + return args, None + # returns the identical args if there is a grad tensor + for obj in args: + if _is_grad_tensor(obj): + return args, None + # otherwise, the first arguement should be a tuple of grad tensors + # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered + arg_zero = args[0] + if not isinstance(arg_zero, tuple): + raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.") + check_grad_flag = False + for obj in arg_zero: + check_grad_flag |= _is_grad_tensor(obj) + if not check_grad_flag: + raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.") + return arg_zero, args[1:] def _get_colo_tensors_info(*args) -> list: diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index d96040817..2831b10a3 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -1,18 +1,15 @@ import math -import operator from copy import deepcopy from dataclasses import dataclass -from enum import Enum -from functools import reduce -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple +import numpy as np import torch -import torch.distributed as dist -from torch.distributed import ReduceOp +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta -from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec -from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator +from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException +from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator from .comm_spec import * @@ -28,6 +25,15 @@ class ShapeConsistencyOptions: pass +def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor: + shape_consistency_manager = ShapeConsistencyManager() + global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {}) + with torch.no_grad(): + global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec, + global_sharding_spec) + return global_tensor + + def set_shape_consistency_options(options: ShapeConsistencyOptions): """ Configure the shape consistency manager via function call. @@ -63,7 +69,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): assert isinstance(value, bool) self._forward_only = value - def get_all_all_gather_spec(self, source_spec, orig_cost_dict): + def get_all_all_gather_spec(self, source_spec: ShardingSpec, + orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: ''' Get all valid sharding specs from source_spec with single all-gather operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. @@ -71,7 +78,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. - orig_cost(float): the original communication cost before this operation. + orig_cost(Dict[str, float]): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation. @@ -83,7 +90,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) shape_consistency_manager = ShapeConsistencyManager() - rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0) + rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) print(rst_dict) Output: @@ -134,7 +141,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): pass return valid_spec_dict - def get_all_all_to_all_spec(self, source_spec, orig_cost_dict): + def get_all_all_to_all_spec(self, source_spec: ShardingSpec, + orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: ''' Get all valid sharding specs from source_spec with single all-to-all operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. @@ -142,7 +150,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. - orig_cost(float): the original communication cost before this operation. + orig_cost(Dict[str, float]): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. @@ -154,7 +162,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) shape_consistency_manager = ShapeConsistencyManager() - rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0) + rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) print(rst_dict) Output: @@ -241,7 +249,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): return valid_spec_dict - def get_all_shard_spec(self, source_spec, orig_cost_dict): + def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): ''' Get all valid sharding specs from source_spec with single shard operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. @@ -261,7 +269,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) shape_consistency_manager = ShapeConsistencyManager() - rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0) + rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) print(rst_dict) Output: @@ -322,7 +330,60 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): pass return valid_spec_dict - def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict): + def get_all_mix_gather_spec(self, source_spec: ShardingSpec, + orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: + ''' + S0S1 -> RR + S1S0 -> RR + S01R -> RR + RS01 -> RR + ''' + valid_spec_dict = {} + comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD + tensor_dims = len(source_spec.entire_shape) + for f_index in range(tensor_dims - 1): + for b_index in range(f_index + 1, tensor_dims): + if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict): + continue + else: + if f_index in source_spec.dim_partition_dict: + # skip (S10, R) -> (R, R) + if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]: + continue + f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) + else: + f_target_pair = (f_index, []) + if b_index in source_spec.dim_partition_dict: + # skip (R, S10) -> (R, R) + if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]: + continue + b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) + else: + b_target_pair = (b_index, []) + + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + comm_spec = CommSpec(comm_pathern, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=self.forward_only, + mix_gather=True) + cost_dict = comm_spec.get_comm_cost() + new_dim_partition_dict = {} + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.device_mesh, + source_spec.entire_shape, + dim_partition_dict=new_dim_partition_dict) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) + except ShardingSpecException: + pass + + return valid_spec_dict + + def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]: ''' Get all valid sharding specs from source_spec with one step transform, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. @@ -344,7 +405,167 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict)) return valid_spec_dict - def shape_consistency(self, source_spec, target_spec): + def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem: + """memory cost of the communication action sequence + + Args: + comm_action_sequence (List[CommSpec]): list of communication actions + + Returns: + TrainCycleItem: memory (numel) cost of such comm_action_sequence + """ + + def compute_shape(sharding_spec: ShardingSpec): + shape = sharding_spec.entire_shape + new_shape = [] + for dim, shard in sharding_spec.dim_partition_dict.items(): + new_shape.append(shape[dim] // len(shard)) + return new_shape + + def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """analyze all_gather memory footprint + all_gather will allocate memory for the output tensor, and there will be temp memory for + all_gather operation, which is twice the size of output tensor + + Args: + comm_spec (CommSpec): input CommSpec + discard_input (bool): whether to discard the input tensor + alloc_numel (int): current allocated numel + peak_numel (int): current peak numel + """ + input_shape = compute_shape(comm_spec.sharding_spec) + input_numel = np.prod(input_shape) + output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + peak_numel = max(peak_numel, alloc_numel + output_numel * 2) + alloc_numel += output_numel + if discard_input: + alloc_numel -= input_numel + + return alloc_numel, peak_numel + + def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """analyze split memory footprint + split will allocate memory for the output tensor if we don't apply shard on the first dimension of + the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not + generate new tensor in this case, so no memory will be allocated. + + Args: + comm_spec (CommSpec): input CommSpec + discard_input (bool): whether to discard the input tensor + alloc_numel (int): current allocated numel + peak_numel (int): current peak numel + """ + shard_dim = comm_spec.shard_dim + if shard_dim != 0: + # if we don't shard the tensor on the first dimension, the split action will + # generate a new tensor + input_shape = compute_shape(comm_spec.sharding_spec) + input_numel = np.prod(input_shape) + output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + alloc_numel += output_numel + peak_numel = max(peak_numel, alloc_numel) + if discard_input: + alloc_numel -= input_numel + else: + # if we shard the tensor on the first dimension, the split action will not generate + # a new tensor, and as it will preserve a reference to the input tensor, we could + # override the discard_input option here + # NOTE: this special case might fail in some weird cases, e.g. if we have three split + # actions in the comm actions sequence, the first split action operate on the second dimension, + # the second split action operate on the first dimension, and the third split action operate, again, + # on the second dimension. Therefore, after the first two actions in the sequence, we will allocate + # memory the same size as the output of first split action. However, the third split action will discard + # the input tensor, and it actually should discard the tensor generated by the first split action, so in + # the current memory estimation framework, we will overestimate the memory usage. But the above case is + # kind of weird, and I think we could ignore it for now. + pass + + return alloc_numel, peak_numel + + def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """ + a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory + """ + return alloc_numel, peak_numel + + def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """analyze all_to_all memory footprint + all_to_all will allocate memory for the output tensor, and temp memory of all_to_all action + is twice the size of output tensor if we shard input tensor on the first dimension, otherwise + the temp memory is three times the size of output tensor + + Args: + comm_spec (CommSpec): input CommSpec + discard_input (bool): whether to discard the input tensor + alloc_numel (int): current allocated numel + peak_numel (int): current peak numel + """ + input_shape = compute_shape(comm_spec.sharding_spec) + input_numel = np.prod(input_shape) + output_numel = input_numel + shard_dim = comm_spec.shard_dim + if shard_dim != 0: + peak_numel = max(peak_numel, alloc_numel + output_numel * 3) + else: + peak_numel = max(peak_numel, alloc_numel + output_numel * 2) + alloc_numel += output_numel + if discard_input: + alloc_numel -= input_numel + + return alloc_numel, peak_numel + + def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): + """ + a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory + """ + return alloc_numel, peak_numel + + pattern_to_func_dict = { + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis], + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis], + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis], + CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis], + CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis], + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [], + } + + fwd_actions = [] + bwd_actions = [] + + # construct forward and backward comm actions sequence + for comm_spec in comm_action_sequence: + comm_spec: CommSpec + fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern] + fwd_actions.append(fwd_action) + bwd_actions.append(bwd_action) + + # analyze memory footprint of forward comm actions sequence + fwd_alloc_numel = 0 + fwd_peak_numel = 0 + for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)): + # the first forward comm action will not discard input + fwd_action, comm_spec = action_spec_pair + fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, + fwd_peak_numel) if idx == 0 else fwd_action( + comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + + # analyze memory footprint for backward comm actions sequence + bwd_alloc_numel = 0 + bwd_peak_numel = 0 + for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): + bwd_action, comm_spec = action_spec_pair + bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel, + bwd_peak_numel) if idx == 0 else bwd_action( + comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + + fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) + bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel) + total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel) + + return TrainCycleItem(fwd_mem, bwd_mem, total_mem) + + def shape_consistency(self, source_spec: ShardingSpec, + target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]: ''' This method will find a path to transform source_spec to target_spec with a greedy algorithm. @@ -450,7 +671,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") - def apply(self, tensor_with_sharding_spec, target_spec): + def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor: ''' Apply target_spec to tensor with source sharding spec, the transform path is generated by the shape_consistency method. diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index fababb6e7..cdd033885 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -1,12 +1,12 @@ import operator from copy import deepcopy -from enum import Enum from functools import reduce import torch from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator) + +from .utils import merge_same_dim_mesh_list __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] @@ -23,7 +23,7 @@ class _DimSpec: This class is used internally in ShardingSpec. Argument: - shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. + shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. Otherwise, the element in shard_list means the data will be sharded in that dimension. ''' @@ -62,7 +62,7 @@ class _DimSpec: def build_difference_2d_dict(self): ''' - Build a difference maping for 2D device mesh case. It will be used to + Build a difference maping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. ''' @@ -159,9 +159,9 @@ class ShardingNotDivisibleError(ShardingSpecException): class ShardingSpec: ''' Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong - to, the entire shape of the tensor before sharded, and the sharding sequence looks like + to, the entire shape of the tensor before sharded, and the sharding sequence looks like [R, R, S0, S1]. - + Argument: device_mesh(DeviceMesh): A logical view of a physical mesh. entire_shape(torch.Size): The entire shape of tensor before sharded. @@ -176,12 +176,19 @@ class ShardingSpec: dim_partition_dict=None, sharding_sequence=None): self.device_mesh = device_mesh + + if isinstance(entire_shape, (list, tuple)): + entire_shape = torch.Size(entire_shape) self.entire_shape = entire_shape self.dim_partition_dict = dim_partition_dict self.sharding_sequence = sharding_sequence if self.sharding_sequence is None: + assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' + self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape), + dim_partition_dict=self.dim_partition_dict) self.convert_dict_to_shard_sequence() elif self.dim_partition_dict is None: + assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' self.convert_shard_sequence_to_dict() self._sanity_check() @@ -260,10 +267,10 @@ class ShardingSpec: # device_mesh_shape: (4, 4) sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare) print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare)) - + Output: 25 - + Argument: other(ShardingSpec): The ShardingSpec to compared with. diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/tensor/tensor_spec.py index 23dd3b9af..580df9f8f 100644 --- a/colossalai/tensor/tensor_spec.py +++ b/colossalai/tensor/tensor_spec.py @@ -1,14 +1,16 @@ -from typing import Optional -from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern -from .compute_spec import ComputeSpec -from colossalai.tensor import ProcessGroup from dataclasses import dataclass +from typing import Optional + +from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.tensor.process_group import ProcessGroup + +from .compute_spec import ComputeSpec @dataclass class ColoTensorSpec: """ ColoTensorSpec - + A data class for specifications of the `ColoTensor`. It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`. The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`. diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py index b2eda5a8d..0c2ead630 100644 --- a/colossalai/tensor/utils.py +++ b/colossalai/tensor/utils.py @@ -1,7 +1,8 @@ -import torch +from typing import Dict, Iterator, List, Tuple, Union -from typing import Iterator, Tuple, Union +import torch import torch.nn as nn + from colossalai.tensor.colo_tensor import ColoTensor @@ -12,7 +13,7 @@ def all_gather_simulator(target_pair): We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed. Therefore, all gather operation just remove the last element in shard list, - e.g.: + e.g.: all-gather(S01) -> S0 Argument: @@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair): and simulate the influence of the DimSpec. We BANNED all representations which shard_list in decreasing order, - such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed. + such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed. Therefore, if the behind shard_list is not None, we just extend it to the front shard_list. Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element decribes which logical axis will be sharded in that dimension. - e.g.: + e.g.: all-to-all(S0, S1) -> [S01, R] all-to-all(S0, R) -> [R, S0] Otherwise, we extend the front shard_list to behind. - e.g.: + e.g.: all-to-all(R, S1) -> [S1, R] - + Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element decribes which logical axis will be sharded in that dimension. @@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims): and simulate the influence of the DimSpec. We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed. - In addition, We BANNED all representations which shard_list in decreasing order, + In addition, We BANNED all representations which shard_list in decreasing order, such as S10, so shard(S0) -> S10 is NOT allowed. Therefore, for the R dimension, we could just append any legal sharding dim on it. e.g.: @@ -89,6 +90,31 @@ def shard_simulator(target_pair, legal_sharding_dims): return shard_list_list +def mix_gather_simulator(f_target_pair, b_target_pair): + ''' + Assume index of f and b target pairs are 'f' and 'b' + S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0) + S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1) + S01R => Input: (f, [0, 1]), (b, []) Output: [f], (1, 1) + RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1) + S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0) + RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0) + ''' + if f_target_pair[1] and b_target_pair[1]: + leading_dim = b_target_pair[1] > f_target_pair[1] + return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)] + if f_target_pair[1]: + leading_dim = f_target_pair[1][0] < f_target_pair[1][1] + return [ + f_target_pair[0], + ], [int(leading_dim), int(leading_dim)] + if b_target_pair[1]: + leading_dim = b_target_pair[1][0] < b_target_pair[1][1] + return [ + b_target_pair[0], + ], [int(leading_dim), int(leading_dim)] + + # The function is credited to PyTorch Team def named_params_with_colotensor( module: nn.Module, @@ -164,3 +190,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str): # Now we can set the attribute appropriately. setattr(module, param_name, st) + + +def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: + ''' + This method is used to convert the negative dim value to positive. + ''' + dims_to_convert = [] + for dim, mesh_list in dim_partition_dict.items(): + if dim < 0: + dims_to_convert.append(dim) + for dim in dims_to_convert: + dim_partition_dict.pop(dim) + dim_partition_dict[dim_size + dim] = mesh_list + return dim_partition_dict + + +def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: + ''' + This method is used to merge the different key value which points to same physical position. + + For example: + dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position. + In this method, above dim_partition_dict will be converted to {1: [0, 1]} + ''' + converted_dim_partition_dict = {} + for dim, mesh_list in dim_partition_dict.items(): + if dim < 0: + dim = dim_size + dim + if dim not in converted_dim_partition_dict: + converted_dim_partition_dict[dim] = mesh_list + else: + converted_dim_partition_dict[dim].extend(mesh_list) + + return converted_dim_partition_dict diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index de4f460c0..e00d0da16 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist from torch import Tensor from torch.distributed import ProcessGroup +from torch.testing import assert_close def assert_equal(a: Tensor, b: Tensor): @@ -12,12 +13,8 @@ def assert_not_equal(a: Tensor, b: Tensor): assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}' -def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8): - assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}' - - def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): - assert_close(a, b, rtol, atol) + assert_close(a, b, rtol=rtol, atol=atol) def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): @@ -30,4 +27,4 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}' + assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}' diff --git a/colossalai/testing/random.py b/colossalai/testing/random.py new file mode 100644 index 000000000..ad6d24a4b --- /dev/null +++ b/colossalai/testing/random.py @@ -0,0 +1,19 @@ +import random + +import numpy as np +import torch + + +def seed_all(seed, cuda_deterministic=False): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if cuda_deterministic: # slower, more reproducible + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True diff --git a/colossalai/trainer/hooks/_mem_tracer_hook.py b/colossalai/trainer/hooks/_mem_tracer_hook.py deleted file mode 100644 index 29c5d9b3c..000000000 --- a/colossalai/trainer/hooks/_mem_tracer_hook.py +++ /dev/null @@ -1,44 +0,0 @@ -from colossalai.registry import HOOKS -from torch import Tensor -from colossalai.trainer.hooks import BaseHook -from colossalai.gemini.memory_tracer import AsyncMemoryMonitor - - -@HOOKS.register_module -class MemTraceHook(BaseHook): - """Save memory stats and pass it to states - This hook is used to record memory usage info, and pass to trainer.states - You can use it as other trainer hook and fetch data from trainer.states['metrics][mode] - """ - - def __init__( - self, - priority: int = 0, - ) -> None: - super().__init__(priority=priority) - self._memory_monitor = AsyncMemoryMonitor() - - def after_hook_is_attached(self, trainer): - # Initialize the data - trainer.states['metrics']['train'] = self._memory_monitor.state_dict - trainer.states['metrics']['test'] = self._memory_monitor.state_dict - - def before_train_iter(self, trainer): - self._memory_monitor.start() - return super().before_train_iter(trainer) - - def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): - self._memory_monitor.finish() - trainer.states['metrics']['train'] = self._memory_monitor.state_dict - trainer.states['metrics']['test'] = self._memory_monitor.state_dict - return super().after_train_iter(trainer, output, label, loss) - - def before_test_iter(self, trainer): - self._memory_monitor.start() - return super().before_test(trainer) - - def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): - self._memory_monitor.finish() - trainer.states['metrics']['train'] = self._memory_monitor.state_dict - trainer.states['metrics']['test'] = self._memory_monitor.state_dict - return super().after_test_iter(trainer, output, label, loss) diff --git a/colossalai/utils/checkpoint_io/__init__.py b/colossalai/utils/checkpoint_io/__init__.py new file mode 100644 index 000000000..fe0308668 --- /dev/null +++ b/colossalai/utils/checkpoint_io/__init__.py @@ -0,0 +1,2 @@ +from .io import load, merge, redist, save +from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta) diff --git a/colossalai/utils/checkpoint_io/backend.py b/colossalai/utils/checkpoint_io/backend.py new file mode 100644 index 000000000..140192c05 --- /dev/null +++ b/colossalai/utils/checkpoint_io/backend.py @@ -0,0 +1,74 @@ +import shutil +import tempfile +from abc import ABC, abstractmethod +from typing import Dict, List, Type + +from .reader import CheckpointReader, DiskCheckpointReader +from .writer import CheckpointWriter, DiskCheckpointWriter + +_backends: Dict[str, Type['CheckpointIOBackend']] = {} + + +def register(name: str): + assert name not in _backends, f'"{name}" is registered' + + def wrapper(cls): + _backends[name] = cls + return cls + + return wrapper + + +def get_backend(name: str) -> 'CheckpointIOBackend': + assert name in _backends, f'Unsupported backend "{name}"' + return _backends[name]() + + +class CheckpointIOBackend(ABC): + + def __init__(self) -> None: + super().__init__() + self.temps: List[str] = [] + + @abstractmethod + def get_writer(self, + base_name: str, + overwrite: bool = False, + rank: int = 0, + world_size: int = 1) -> CheckpointWriter: + pass + + @abstractmethod + def get_reader(self, base_name: str) -> CheckpointReader: + pass + + @abstractmethod + def get_temp(self, base_name: str) -> str: + pass + + @abstractmethod + def clean_temp(self) -> None: + pass + + +@register('disk') +class CheckpointDiskIO(CheckpointIOBackend): + + def get_writer(self, + base_name: str, + overwrite: bool = False, + rank: int = 0, + world_size: int = 1) -> CheckpointWriter: + return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size) + + def get_reader(self, base_name: str) -> CheckpointReader: + return DiskCheckpointReader(base_name) + + def get_temp(self, base_name: str) -> str: + temp_dir_name = tempfile.mkdtemp(dir=base_name) + self.temps.append(temp_dir_name) + return temp_dir_name + + def clean_temp(self) -> None: + for temp_dir_name in self.temps: + shutil.rmtree(temp_dir_name) diff --git a/colossalai/utils/checkpoint_io/constant.py b/colossalai/utils/checkpoint_io/constant.py new file mode 100644 index 000000000..219948474 --- /dev/null +++ b/colossalai/utils/checkpoint_io/constant.py @@ -0,0 +1,9 @@ +import re + +GLOBAL_META_FILE_NAME = 'global_meta.bin' +MODEL_CKPT_FILE_NAME = 'model.bin' +OPTIM_CKPT_FILE_NAME = 'optim.bin' +META_CKPT_FILE_NAME = 'meta.bin' +OTHER_CKPT_FILE_NAME = 'other.bin' + +CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other') diff --git a/colossalai/utils/checkpoint_io/convertor.py b/colossalai/utils/checkpoint_io/convertor.py new file mode 100644 index 000000000..529ceb868 --- /dev/null +++ b/colossalai/utils/checkpoint_io/convertor.py @@ -0,0 +1,227 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional + +from torch import Tensor + +from .distributed import merge_param, unmerge_param +from .meta import ParamDistMeta, RedistMeta +from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none) + + +class CheckpointConvertor(ABC): + + @abstractmethod + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + pass + + @abstractmethod + def complete(self) -> None: + pass + + +class ModelCheckpointConvertor(CheckpointConvertor): + + def __init__(self, param_count: Dict[str, int]) -> None: + super().__init__() + self.param_count = param_count + self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict) + + @abstractmethod + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + pass + + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for rank, state_dict in shard_dict.items(): + for k, tensor in state_dict.items(): + self.buffer[k][rank] = tensor + converted_keys = set() + for k, rank_dict in self.buffer.items(): + if len(rank_dict) == self.param_count[k]: + tensors = [] + dist_metas = [] + for rank, tensor in rank_dict.items(): + tensors.append(tensor) + if dist_meta_list[rank] is not None: + dist_metas.append(dist_meta_list[rank][k]) + self.convert_tensors(k, tensors, dist_metas) + converted_keys.add(k) + for k in converted_keys: + del self.buffer[k] + + def complete(self) -> None: + assert len(self.buffer) == 0 + + +class ModelCheckpointMerger(ModelCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None: + super().__init__(param_count) + self.sharder = ModelCheckpointSharder(max_shard_size) + self.save_fn = save_fn + + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) == len(tensors) + tensor = merge_param(tensors, dist_metas) + shard = self.sharder.append(key, tensor) + run_if_not_none(self.save_fn, shard) + + def complete(self) -> None: + super().complete() + run_if_not_none(self.save_fn, self.sharder.complete()) + + +class ModelCheckpointRedistor(ModelCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], + redist_meta: RedistMeta) -> None: + super().__init__(param_count) + self.save_fns = save_fns + self.redist_meta = redist_meta + nprocs = len(save_fns) + self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)] + self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for k, rank_meta in redist_meta.rank_meta.items(): + for rank, rank_info in rank_meta.items(): + self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) + + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + if len(dist_metas) == 0: + # already global + tensor = tensors[0] + else: + assert len(dist_metas) == len(tensors) + tensor = merge_param(tensors, dist_metas) + for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])): + for dp_rank, t in enumerate(tensor_list): + for rank in self.rank_map[key][tp_rank][dp_rank]: + shard = self.sharders[rank].append(key, t) + run_if_not_none(self.save_fns[rank], shard) + + def complete(self) -> None: + super().complete() + for rank, save_fn in enumerate(self.save_fns): + run_if_not_none(save_fn, self.sharders[rank].complete()) + + +class OptimizerCheckpointConvertor(CheckpointConvertor): + + def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]], + paired_os: Optional[Dict[int, dict]]) -> None: + super().__init__() + self.param_count = param_count + self.param_to_os = param_to_os + self.paired_os = paired_os + self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict) + self.os_to_param = {v: k for k, v in param_to_os.items()} + + @abstractmethod + def setup(self, param_groups: dict) -> None: + pass + + @abstractmethod + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + pass + + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for rank, state_dict in shard_dict.items(): + self.setup(state_dict['param_groups']) + for idx, state in state_dict['state'].items(): + self.buffer[idx][rank] = state + converted_indices = set() + for idx, rank_dict in self.buffer.items(): + if len(rank_dict) == self.param_count[self.os_to_param[idx]]: + states = [] + dist_metas = [] + for rank, state in rank_dict.items(): + states.append(state) + if dist_meta_list[rank] is not None: + dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]]) + self.convert_states(idx, states, dist_metas) + converted_indices.add(idx) + for idx in converted_indices: + del self.buffer[idx] + + def complete(self) -> None: + assert len(self.buffer) == 0 + + +class OptimizerCheckpointMerger(OptimizerCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int], + param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None: + super().__init__(param_count, param_to_os, paired_os) + self.max_shard_size = max_shard_size + self.save_fn = save_fn + self.sharder = None + + def setup(self, param_groups: dict) -> None: + if self.sharder is None: + self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups) + + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) == len(states) + new_state = {} + for state_key, state_tensor in states[0].items(): + if self.paired_os[idx][state_key]: + new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas) + else: + new_state[state_key] = state_tensor + shard = self.sharder.append(idx, new_state) + run_if_not_none(self.save_fn, shard) + + def complete(self) -> None: + super().complete() + run_if_not_none(self.save_fn, self.sharder.complete()) + + +class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], + param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]], + redist_meta: RedistMeta) -> None: + super().__init__(param_count, param_to_os, paired_os) + self.max_shard_size = max_shard_size + self.save_fns = save_fns + self.redist_meta = redist_meta + self.sharders: List[OptimizerCheckpointSharder] = [] + self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for k, rank_meta in redist_meta.rank_meta.items(): + for rank, rank_info in rank_meta.items(): + self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) + + def setup(self, param_groups: dict) -> None: + if len(self.sharders) == 0: + nprocs = len(self.save_fns) + for _ in range(nprocs): + self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups)) + + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + need_merge: bool = True + if len(dist_metas) == 0: + need_merge = False + else: + assert len(dist_metas) == len(states) + new_states = [{} for _ in range(len(self.save_fns))] + for state_key, state_tensor in states[0].items(): + if self.paired_os[idx][state_key]: + if need_merge: + tensor = merge_param([state[state_key] for state in states], dist_metas) + else: + tensor = state_tensor + for tp_rank, tensor_list in enumerate( + unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])): + for dp_rank, t in enumerate(tensor_list): + for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]: + new_states[rank][state_key] = t + else: + for new_state in new_states: + new_state[state_key] = state_tensor + for rank, new_state in enumerate(new_states): + shard = self.sharders[rank].append(idx, new_state) + run_if_not_none(self.save_fns[rank], shard) + + def complete(self) -> None: + super().complete() + for rank, save_fn in enumerate(self.save_fns): + run_if_not_none(save_fn, self.sharders[rank].complete()) diff --git a/colossalai/utils/checkpoint_io/distributed.py b/colossalai/utils/checkpoint_io/distributed.py new file mode 100644 index 000000000..bf720437c --- /dev/null +++ b/colossalai/utils/checkpoint_io/distributed.py @@ -0,0 +1,127 @@ +import torch +from numpy import prod +from torch import Tensor +from typing import List, Optional, Tuple +from collections import defaultdict +from .meta import ParamDistMeta, ParamRedistMeta + + +def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: + assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) + for dist_meta in dist_metas[1:]: + assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.' + if not dist_metas[0].used_zero: + # tensors are replicate + return tensors[0] + numel = dist_metas[0].zero_numel + orig_shape = dist_metas[0].zero_orig_shape + tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)] + assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.' + return torch.cat(tensors).reshape(orig_shape) + + +def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: + assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) + for dist_meta in dist_metas[1:]: + assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.' + for t in tensors[1:]: + assert t.shape == tensors[0].shape, 'Expect all params have the same shape.' + if not dist_metas[0].used_tp: + # tensors are replicate + return tensors[0] + total_parts = prod(dist_meta.tp_num_parts) + assert dist_meta.tp_world_size == total_parts, \ + f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_meta.tp_world_size}.' + shard_info = sorted(zip(dist_meta.tp_shard_dims, dist_meta.tp_num_parts), key=lambda t: t[0], reverse=True) + for dim, num_parts in shard_info: + buffer = [] + for start in range(0, len(tensors), num_parts): + buffer.append(torch.cat(tensors[start:start + num_parts], dim)) + tensors = buffer + assert len(tensors) == 1 + return tensors[0] + + +def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) > 0 + # check world size + for dist_meta in dist_metas[1:]: + assert dist_meta.dp_world_size == dist_metas[ + 0].dp_world_size, 'Expect all dist meta have the same dp_world_size' + assert dist_meta.tp_world_size == dist_metas[ + 0].tp_world_size, 'Expect all dist meta have the same tp_world_size' + + +def deduplicate_params(tensors: List[Tensor], + dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]: + unique_dist_meta = [] + unique_idx = [] + for i, dist_meta in enumerate(dist_metas): + if dist_meta not in unique_dist_meta: + unique_dist_meta.append(dist_meta) + unique_idx.append(i) + return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx] + + +def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: + assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) + # validate parallel info + validate_parallel_info(dist_metas) + tensors, dist_metas = deduplicate_params(tensors, dist_metas) + unflattened_tensors = [] + # group zero params by tp rank + tensor_dict = defaultdict(list) + dist_meta_dict = defaultdict(list) + for t, dist_meta in zip(tensors, dist_metas): + tensor_dict[dist_meta.tp_rank].append(t) + dist_meta_dict[dist_meta.tp_rank].append(dist_meta) + assert len(tensor_dict + ) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}' + for tp_rank in tensor_dict.keys(): + unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank])) + return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()]) + + +def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: + if not redist_meta.used_tp: + assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.' + return [tensor] + total_parts = prod(redist_meta.tp_num_parts) + assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.' + shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0]) + tensors = [tensor] + for dim, num_parts in shard_info: + buffer = [] + for t in tensors: + assert t.size(dim) % num_parts == 0, \ + f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.' + chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)] + buffer.extend(chunks) + tensors = buffer + assert len(tensors) == redist_meta.tp_world_size + return tensors + + +def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: + if not redist_meta.used_zero: + return [tensor] * redist_meta.dp_world_size + tensors: List[Optional[Tensor]] = [ + torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank) + ] + offsets = redist_meta.zero_offsets + [tensor.numel()] + for i, offset in enumerate(offsets[:-1]): + end = offsets[i + 1] + tensors.append(tensor.view(-1)[offset:end]) + if len(tensors) < redist_meta.dp_world_size: + tensors.extend([ + torch.empty(0, dtype=tensor.dtype, device=tensor.device) + for _ in range(redist_meta.dp_world_size - len(tensors)) + ]) + assert len(tensors) == redist_meta.dp_world_size + return tensors + + +def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]: + tensors = split_tp_param(tensor, redist_meta) + tensors = [flatten_zero_param(t, redist_meta) for t in tensors] + return tensors diff --git a/colossalai/utils/checkpoint_io/io.py b/colossalai/utils/checkpoint_io/io.py new file mode 100644 index 000000000..f00212cdf --- /dev/null +++ b/colossalai/utils/checkpoint_io/io.py @@ -0,0 +1,170 @@ +import warnings +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +import torch.distributed as dist +from torch.nn import Module +from torch.optim import Optimizer + +from .backend import get_backend +from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger, + OptimizerCheckpointRedistor) +from .meta import ParamDistMeta, RedistMeta +from .utils import build_checkpoints, optimizer_load_state_dict + + +def save(path: str, + model: Module, + optimizer: Optional[Optimizer] = None, + param_to_os: Optional[Dict[str, int]] = None, + dist_meta: Optional[Dict[str, ParamDistMeta]] = None, + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk', + **kwargs: Any) -> None: + io_backend = get_backend(backend) + if dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + if world_size == 1: + # global doesn't need dist_meta + dist_meta = None + else: + assert dist_meta is not None + max_shard_size = int(max_shard_size_gb * 1024**3) + model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer, + param_to_os, dist_meta) + writer = io_backend.get_writer(path, overwrite, rank, world_size) + writer.save_others(kwargs) + for model_checkpoint in model_checkpoints: + writer.save_model(model_checkpoint) + for optimizer_checkpoint in optimizer_checkpoints: + writer.save_optimizer(optimizer_checkpoint) + writer.save_meta(meta_checkpoint) + + +def merge(path: str, + output_path: str, + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk') -> bool: + io_backend = get_backend(backend) + if dist.is_initialized() and dist.get_rank() != 0: + return False + reader = io_backend.get_reader(path) + if len(reader.meta_list) == 1: + # already global + warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.') + return False + dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() + writer = io_backend.get_writer(output_path, overwrite=overwrite) + writer.save_others(reader.load_others()) + max_shard_size = int(max_shard_size_gb * 1024**3) + _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(), + dist_meta_list) + _convert_shards( + OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os), + reader.load_optimizers(), dist_meta_list) + meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())} + if param_to_os is not None: + meta_checkpoint['param_to_os'] = param_to_os + meta_checkpoint['paired_os'] = paired_os + writer.save_meta(meta_checkpoint) + return True + + +def redist(path: str, + output_path: str, + redist_meta: RedistMeta, + dist_metas: List[Dict[str, ParamDistMeta]], + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk') -> bool: + io_backend = get_backend(backend) + if dist.is_initialized() and dist.get_rank() != 0: + return False + nprocs = len(dist_metas) + reader = io_backend.get_reader(path) + dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() + do_redist: bool = False + if len(dist_meta_list) == nprocs: + for a, b in zip(dist_metas, dist_meta_list): + if a != b: + do_redist = True + break + else: + do_redist = True + if not do_redist: + warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.') + return False + + writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)] + writers[0].save_others(reader.load_others()) + max_shard_size = int(max_shard_size_gb * 1024**3) + _convert_shards( + ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta), + reader.load_models(), dist_meta_list) + _convert_shards( + OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count, + param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list) + for writer, dist_meta in zip(writers, dist_metas): + meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())} + if param_to_os is not None: + meta_checkpoint['param_to_os'] = param_to_os + meta_checkpoint['paired_os'] = paired_os + writer.save_meta(meta_checkpoint) + return True + + +def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None], + dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for shard_dict in shard_generator: + convertor.append(shard_dict, dist_meta_list) + convertor.complete() + + +def load(path: str, + model: Module, + optimizer: Optional[Optimizer] = None, + redist_meta: Optional[RedistMeta] = None, + dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None, + max_shard_size_gb: float = 0.0, + backend: str = 'disk') -> dict: + is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1 + rank: int = dist.get_rank() if dist.is_initialized() else 0 + is_main_process: bool = rank == 0 + # validate args + if redist_meta is None or dist_metas is None: + assert is_global + io_backend = get_backend(backend) + read_path: str = path + if is_main_process: + # pre-process checkpoints + temp_path = io_backend.get_temp(path) + if is_global: + wrote = merge(path, temp_path, max_shard_size_gb, backend=backend) + else: + wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend) + if wrote: + read_path = temp_path + if not is_global: + bcast_list = [read_path] if is_main_process else [None] + dist.broadcast_object_list(bcast_list) + read_path = bcast_list[0] + reader = io_backend.get_reader(read_path) + # load model + for shard in reader.load_model(rank): + model.load_state_dict(shard, strict=False) + if optimizer is not None: + for shard in reader.load_optimizer(rank): + # optimizer.load_state_dict(shard) + optimizer_load_state_dict(optimizer, shard) + others_dict = reader.load_others() + if not is_global: + dist.barrier() + # clean up temp + if is_main_process: + io_backend.clean_temp() + return others_dict diff --git a/colossalai/utils/checkpoint_io/meta.py b/colossalai/utils/checkpoint_io/meta.py new file mode 100644 index 000000000..994f08b4b --- /dev/null +++ b/colossalai/utils/checkpoint_io/meta.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass +from typing import List, Optional, Set, Dict + + +@dataclass +class ParamDistMeta: + # parallel info + dp_rank: int + dp_world_size: int + tp_rank: int + tp_world_size: int + # tp info + tp_shard_dims: Optional[List[int]] = None + tp_num_parts: Optional[List[int]] = None + # zero info + zero_numel: Optional[int] = None + zero_orig_shape: Optional[List[int]] = None + + @property + def used_tp(self) -> bool: + return self.tp_shard_dims is not None and self.tp_num_parts is not None + + @property + def used_zero(self) -> bool: + return self.zero_numel is not None and self.zero_orig_shape is not None + + @property + def parallel_meta(self) -> tuple: + return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size + + @property + def tp_meta(self) -> tuple: + return self.tp_shard_dims, self.tp_num_parts + + @property + def zero_meta(self) -> tuple: + return self.zero_numel, self.zero_orig_shape + + @staticmethod + def from_dict(d: dict) -> 'ParamDistMeta': + return ParamDistMeta(**d) + + +@dataclass +class ParamRedistMeta: + # parallel info + dp_world_size: int + tp_world_size: int + # tp info + tp_shard_dims: Optional[List[int]] = None + tp_num_parts: Optional[List[int]] = None + # zero info + zero_start_dp_rank: Optional[int] = None + zero_offsets: Optional[List[int]] = None + + @property + def used_tp(self) -> bool: + return self.tp_shard_dims is not None and self.tp_num_parts is not None + + @property + def used_zero(self) -> bool: + return self.zero_start_dp_rank is not None and self.zero_offsets is not None + + +@dataclass +class RankRedistMeta: + dp_rank: int + tp_rank: int + pp_rank: int + + +@dataclass +class PipelineRedistMeta: + params: Set[str] + + +@dataclass +class RedistMeta: + rank_meta: Dict[str, Dict[int, RankRedistMeta]] + pipeline_meta: List[PipelineRedistMeta] + param_meta: Dict[str, ParamRedistMeta] diff --git a/colossalai/utils/checkpoint_io/reader.py b/colossalai/utils/checkpoint_io/reader.py new file mode 100644 index 000000000..3158c6481 --- /dev/null +++ b/colossalai/utils/checkpoint_io/reader.py @@ -0,0 +1,131 @@ +import os +from abc import ABC, abstractmethod +from collections import Counter +from typing import Dict, Generator, List, Optional, Tuple + +import torch + +from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME +from .meta import ParamDistMeta +from .utils import is_duplicated_list + + +class CheckpointReader(ABC): + + def __init__(self, base_name: str) -> None: + super().__init__() + self.base_name = base_name + self.meta_list = [] + + @abstractmethod + def read(self, name: str) -> dict: + pass + + @abstractmethod + def load_meta( + self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: + pass + + @abstractmethod + def load_model(self, rank: int) -> Generator[dict, None, None]: + pass + + @abstractmethod + def load_models(self) -> Generator[Dict[int, dict], None, None]: + pass + + @abstractmethod + def load_optimizer(self, rank: int) -> Generator[dict, None, None]: + pass + + @abstractmethod + def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: + pass + + @abstractmethod + def load_others(self) -> dict: + pass + + +class DiskCheckpointReader(CheckpointReader): + + def __init__(self, base_name: str) -> None: + super().__init__(base_name) + assert os.path.isdir(base_name), f'"{base_name}" is not a directory' + global_meta = self.read(GLOBAL_META_FILE_NAME) + for meta_file_name in global_meta['meta']: + meta = self.read(meta_file_name) + if meta.get('dist_meta', None) is None: + # only global checkpoint can have empty dist_meta + assert len(global_meta['meta']) == 1 + self.meta_list.append(meta) + + def read(self, name: str) -> dict: + return torch.load(os.path.join(self.base_name, name)) + + def load_meta( + self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: + meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', + None), meta.get('paired_os', None)) + for meta in self.meta_list] + dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos) + # reduce param_count + param_count = Counter(p for params in params_list for p in params) + # validate param_to_os + assert is_duplicated_list(param_to_os_list) + assert is_duplicated_list(paired_os_list) + return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0] + + def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]: + meta = self.meta_list[rank] + checkpoint_names = meta.get(shard_type, []) + for name in checkpoint_names: + yield self.read(name) + + def load_model(self, rank: int) -> Generator[dict, None, None]: + return self._load_shard('model', rank) + + def load_models(self) -> Generator[Dict[int, dict], None, None]: + indices = [0] * len(self.meta_list) + while True: + shards = {} + for i, meta in enumerate(self.meta_list): + model_checkpoint_names = meta.get('model', []) + if indices[i] < len(model_checkpoint_names): + shards[i] = self.read(model_checkpoint_names[indices[i]]) + indices[i] += 1 + if len(shards) > 0: + yield shards + else: + break + + def load_optimizer(self, rank: int) -> Generator[dict, None, None]: + param_groups = None + for shard in self._load_shard('optimizer', rank): + if param_groups is None: + param_groups = shard['param_groups'] + else: + shard['param_groups'] = param_groups + yield shard + + def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: + indices = [0] * len(self.meta_list) + param_groups = [] + while True: + shards = {} + for i, meta in enumerate(self.meta_list): + optimizer_checkpoint_names = meta.get('optimizer', []) + if indices[i] < len(optimizer_checkpoint_names): + shards[i] = self.read(optimizer_checkpoint_names[indices[i]]) + if indices[i] == 0: + param_groups.append(shards[i]['param_groups']) + else: + shards[i]['param_groups'] = param_groups[i] + indices[i] += 1 + if len(shards) > 0: + yield shards + else: + break + + def load_others(self) -> dict: + return self.read(OTHER_CKPT_FILE_NAME) diff --git a/colossalai/utils/checkpoint_io/utils.py b/colossalai/utils/checkpoint_io/utils.py new file mode 100644 index 000000000..135385f57 --- /dev/null +++ b/colossalai/utils/checkpoint_io/utils.py @@ -0,0 +1,223 @@ +import warnings +from copy import deepcopy +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Tuple + +from torch import Tensor +from torch.nn import Module +from torch.nn.parameter import Parameter +from torch.optim import Optimizer + +from .meta import ParamDistMeta + + +def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any: + if arg is not None: + return fn(arg) + + +def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]: + # ensure all params in optimizer are in model state dict + params_set = set(id(p) for p in model.parameters()) + for group in optimizer.param_groups: + for p in group['params']: + assert id(p) in params_set + param_mappings = {} + start_index = 0 + + def get_group_mapping(group): + nonlocal start_index + param_mappings.update( + {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) + start_index += len(group['params']) + + for g in optimizer.param_groups: + get_group_mapping(g) + return {k: param_mappings[id(p)] for k, p in model.named_parameters()} + + +def compute_optimizer_state_size(state: Dict[str, Any]) -> int: + size = 0 + for v in state.values(): + if isinstance(v, Tensor): + size += v.numel() * v.element_size() + return size + + +class ModelCheckpointSharder: + + def __init__(self, max_shard_size: int) -> None: + self.max_shard_size = max_shard_size + self.buffer: Dict[str, Tensor] = {} + self.buffer_size: int = 0 + + def append(self, key: str, tensor: Tensor) -> Optional[dict]: + retval = None + if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: + retval = self.buffer + self.buffer = {} + self.buffer_size = 0 + self.buffer[key] = tensor + self.buffer_size += tensor.numel() * tensor.element_size() + return retval + + def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]: + shards = [] + for key, tensor in state_dict.items(): + shard = self.append(key, tensor) + run_if_not_none(shards.append, shard) + return shards + + def complete(self) -> Optional[dict]: + return self.buffer if len(self.buffer) > 0 else None + + +class OptimizerCheckpointSharder: + + def __init__(self, max_shard_size: int, param_groups: dict) -> None: + self.max_shard_size = max_shard_size + self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups} + self.buffer_size: int = 0 + self.returned_first: bool = False + + def append(self, key: int, state: dict) -> Optional[dict]: + retval = None + if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: + retval = self.buffer + self.buffer = {'state': {}} + self.buffer_size = 0 + self.buffer['state'][key] = state + self.buffer_size += compute_optimizer_state_size(state) + return retval + + def extend(self, state_dict: Dict[str, dict]) -> List[dict]: + shards = [] + for key, state in state_dict['state'].items(): + shard = self.append(key, state) + run_if_not_none(shards.append, shard) + return shards + + def complete(self) -> Optional[dict]: + return self.buffer if len(self.buffer['state']) > 0 else None + + +def shard_checkpoint(max_shard_size: int, + model_state_dict: Dict[str, Tensor], + optimizer_state_dict: Optional[dict] = None, + param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]: + has_optimizer: bool = False + if optimizer_state_dict is not None: + assert param_to_os is not None + os_to_param = {v: k for k, v in param_to_os.items()} + for os_key in optimizer_state_dict['state'].keys(): + assert os_key in os_to_param + assert os_to_param[os_key] in model_state_dict + has_optimizer = True + model_sharder = ModelCheckpointSharder(max_shard_size) + model_shards = model_sharder.extend(model_state_dict) + run_if_not_none(model_shards.append, model_sharder.complete()) + if not has_optimizer: + return model_shards, [] + optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups']) + optimizer_shards = optimizer_sharder.extend(optimizer_state_dict) + run_if_not_none(optimizer_shards.append, optimizer_sharder.complete()) + return model_shards, optimizer_shards + + +def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict: + os_to_param = {v: k for k, v in param_to_os.items()} + paired_os = {} + for idx, state in optimizer_state_dict['state'].items(): + paired_os[idx] = {} + p = model_state_dict[os_to_param[idx]] + for k, v in state.items(): + if isinstance(v, Tensor) and v.shape == p.shape: + paired_os[idx][k] = True + else: + paired_os[idx][k] = False + return paired_os + + +def build_checkpoints(max_size: int, + model: Module, + optimizer: Optional[Optimizer] = None, + param_to_os: Optional[Dict[str, int]] = None, + dist_meta: Optional[Dict[str, ParamDistMeta]] = None, + eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]: + save_global = dist_meta is None + model_state_dict = model.state_dict() + optimizer_state_dict = optimizer.state_dict() if optimizer else None + meta = {'dist_meta': dist_meta} + if optimizer: + param_to_os = param_to_os or get_param_to_os(model, optimizer) + paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os) + meta['param_to_os'] = param_to_os + meta['paired_os'] = paired_os + if not save_global and eliminate_replica: + # filter dp replicated params + model_state_dict = { + k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 + } + if optimizer: + optimizer_state_dict['state'] = { + param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]] + for k in model_state_dict.keys() + if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 + } + meta['params'] = list(model_state_dict.keys()) + if len(model_state_dict) == 0: + warnings.warn('model state dict is empty, checkpoint is not saved') + return [], [], meta + model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict, + param_to_os) + return model_checkpoints, optimizer_checkpoints, meta + + +def is_duplicated_list(list_: List[Any]) -> bool: + if len(list_) == 0: + return True + elem = list_[0] + for x in list_[1:]: + if x != elem: + return False + return True + + +def copy_optimizer_state(src_state: dict, dest_state: dict) -> None: + for k, v in src_state.items(): + if k in dest_state: + old_v = dest_state[k] + if isinstance(old_v, Tensor): + old_v.copy_(v) + else: + dest_state[k] = v + + +def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None: + assert optimizer.state_dict()['param_groups'] == state_dict['param_groups'] + state_dict = deepcopy(state_dict) + groups = optimizer.param_groups + saved_groups = state_dict['param_groups'] + idx_to_p: Dict[str, Parameter] = { + old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups + )), chain.from_iterable((g['params'] for g in groups))) + } + missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys())) + unexpected_keys = [] + error_msgs = [] + for idx, state in state_dict['state'].items(): + if idx in idx_to_p: + old_state = optimizer.state[idx_to_p[idx]] + copy_optimizer_state(state, old_state) + else: + unexpected_keys.append(idx) + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__, + "\n\t".join(error_msgs))) diff --git a/colossalai/utils/checkpoint_io/writer.py b/colossalai/utils/checkpoint_io/writer.py new file mode 100644 index 000000000..4552accde --- /dev/null +++ b/colossalai/utils/checkpoint_io/writer.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +from typing import Optional +from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME +import torch +import os + + +class CheckpointWriter(ABC): + + def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: + super().__init__() + self.base_name = base_name + self.overwrite = overwrite + self.rank = rank + self.world_size = world_size + self.is_distributed = world_size > 1 + self.is_main_process = rank == 0 + + @abstractmethod + def write(self, name: str, state_dict: dict) -> None: + pass + + @abstractmethod + def save_model(self, model_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_optimizer(self, optimizer_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_meta(self, meta_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_others(self, kwargs: dict) -> None: + pass + + +class DiskCheckpointWriter(CheckpointWriter): + + def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: + super().__init__(base_name, overwrite, rank, world_size) + if not os.path.exists(base_name): + os.makedirs(base_name) + assert os.path.isdir(base_name), f'"{base_name}" is not a directory' + self.model_checkpoint_names = [] + self.optimizer_checkpoint_names = [] + self.is_meta_saved: bool = False + self._save_global_meta() + + def write(self, name: str, state_dict: dict) -> None: + path = os.path.join(self.base_name, name) + if os.path.exists(path) and not self.overwrite: + raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)') + torch.save(state_dict, path) + + def _save_global_meta(self) -> None: + if self.is_main_process: + global_meta = {'meta': []} + if self.is_distributed: + for i in range(self.world_size): + global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin')) + else: + global_meta['meta'].append(META_CKPT_FILE_NAME) + self.write(GLOBAL_META_FILE_NAME, global_meta) + + def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str: + checkpoint_name = base_name + if self.is_distributed: + checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin') + if shard_idx is not None: + checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin') + return checkpoint_name + + def save_model(self, model_checkpoint: dict) -> None: + assert not self.is_meta_saved, 'Cannot save model after saving meta' + name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names)) + self.write(name, model_checkpoint) + self.model_checkpoint_names.append(name) + + def save_optimizer(self, optimizer_checkpoint: dict) -> None: + assert not self.is_meta_saved, 'Cannot save optimizer after saving meta' + name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names)) + self.write(name, optimizer_checkpoint) + self.optimizer_checkpoint_names.append(name) + + def save_meta(self, meta_checkpoint: dict) -> None: + if len(self.model_checkpoint_names) > 0: + meta_checkpoint['model'] = self.model_checkpoint_names + if len(self.optimizer_checkpoint_names) > 0: + meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names + self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint) + self.is_meta_saved = True + + def save_others(self, kwargs: dict) -> None: + if self.is_main_process: + self.write(OTHER_CKPT_FILE_NAME, kwargs) diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index a52c25530..7575fa292 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -1,32 +1,31 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import functools import os import random import socket +from collections import defaultdict +from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Union, Dict, Optional -import functools +from typing import Callable, Dict, List, Optional, Union import torch +import torch.distributed as dist from torch._six import inf from torch.nn.parameter import Parameter -try: - import colossal_C -except: - pass - -from contextlib import contextmanager - -import torch.distributed as dist -from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES) +from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env +from colossalai.tensor import ColoParameter, ProcessGroup + from .multi_tensor_apply import multi_tensor_applier -from colossalai.tensor import ColoParameter, ProcessGroup -from collections import defaultdict +try: + from colossalai._C import fused_optim +except: + fused_optim = None def print_rank_0(msg: str, logger=None): @@ -128,11 +127,18 @@ def is_model_parallel_parameter(p): def _calc_l2_norm(grads): + # we should not + global fused_optim + + if fused_optim is None: + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + norm = 0.0 if len(grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) norm, _ = multi_tensor_applier( - colossal_C.multi_tensor_l2norm, + fused_optim.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm @@ -269,7 +275,8 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: cpu_grads.append(p.grad.detach()) if len(cuda_grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef) + multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], + clip_coef) for g in cpu_grads: g.mul_(clip_coef) @@ -395,7 +402,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): if enable_cuda_kernels: grads = [p.grad.detach() for p in params] dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) + multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) else: for p in params: p.grad.detach().mul_(clip_coeff) diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index 3824d27f6..93c91e099 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -1,10 +1,13 @@ -from .utils import InsertPostInitMethodToModuleSubClasses +from typing import Any, Dict, Iterator, Optional, Tuple, Union + import torch -from colossalai.tensor import ColoTensor, ColoParameter -from colossalai.nn.parallel.layers import register_colo_module, \ - ColoLinear, ColoEmbedding from torch import nn -from typing import Iterator, Tuple, Union + +from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module +from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup + +from .utils import InsertPostInitMethodToModuleSubClasses + # find named_params includes replica @@ -23,6 +26,39 @@ def _named_params_with_replica( yield name, val +def _convert_to_coloparam(param: torch.nn.Parameter, + device: torch.device, + dtype=torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec: Optional[Any] = None) -> ColoParameter: + + if isinstance(param, ColoParameter): + return param + # detaching tensor is necessary for optimizers. + requires_grad = param.requires_grad + # param is the global tensor. + + if param.device.type == "meta": + colo_param = ColoParameter(param, requires_grad=requires_grad) + else: + colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad) + + + # if default_shard_plan exists, shard the param during initialization. + # This can reduce the model size after initialization. + # NOTE() embedding usually can not be correctly sharded. So I use except to handle + # the param that can not be sharded by the default plan + if default_pg is not None: + colo_param.set_process_group(default_pg) + + if default_dist_spec is not None: + try: + colo_param.set_dist_spec(default_dist_spec) + except: + pass + return colo_param + + def ColoModulize(module): """ Replacing the parameters() and named_parameters() with our customized ones @@ -34,20 +70,24 @@ def ColoModulize(module): class ColoInitContext(InsertPostInitMethodToModuleSubClasses): def __init__(self, - lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu'), - dtype: torch.dtype = torch.float): + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None): """ Args: - lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False. - device (torch.device, optional): the device parameters initialized are resident on. Defaults to torch.device('cpu'). + device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu'). + dtype (torch.dtype): the dtype of parameters initialized. Defults to torch.float. + default_pg (ProcessGroup): the default process group for all initialized parameters. + default_dist_spec: the default distributed specifications. """ super().__init__() - self._lazy_memory_allocate = lazy_memory_allocate self._device = device self._dtype = dtype self._register_colo_modules() + self._default_pg = default_pg + self._default_dist_spec = default_dist_spec def _register_colo_modules(self): register_colo_module(torch.nn.Linear, ColoLinear()) @@ -61,10 +101,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): The function to call at the end of the constructor of each module. FIXME(fjr) The module may be passed to this function multiple times? """ - - if hasattr(module, '_colo_visited'): - return - name_list = [] for name, param in _named_params_with_replica(module): if isinstance(param, ColoTensor): @@ -87,17 +123,74 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): if param in replaced_tensors: colo_param = replaced_tensors[param] else: - save_torch_payload = True if not self._lazy_memory_allocate else False - # detaching tensor is necessary for optimizers. - requires_grad = param.requires_grad - # TODO(jiaruifang) we initialize a Default PG memory - colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype), - requires_grad=requires_grad) - # add mapping record + colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg, + self._default_dist_spec) replaced_tensors[param] = colo_param delattr(submodule, param_name) setattr(submodule, param_name, colo_param) colo_param.shared_param_modules.append(submodule) + + meta_param_flag = 0 + meta_buffer_flag = 0 + for param in module.parameters(): + if param.device.type=="meta": + meta_param_flag = 1 + if meta_param_flag == 1 and param.device.type!="meta": + raise ValueError("Meta parameters and valued parameters can not be in the same model") + + for buffer in module.buffers(): + if buffer.device.type=="meta": + meta_buffer_flag = 1 + if meta_buffer_flag == 1 and buffer.device.type!="meta": + raise ValueError("Meta buffers and valued buffers can not be in the same model") + + if meta_param_flag==1 and meta_buffer_flag==1: + pass + elif meta_buffer_flag==0 and meta_param_flag==1: + for name, buf in module.named_buffers(): + module._buffers[name] = module._buffers[name].to(device=self._device) + elif meta_param_flag==0 and meta_buffer_flag==1: + for name, param in module.named_parameters(): + module._parameters[name] = module._parameters[name].to(device=self._device) + else: + module.to(self._device) + - module.to(self._device) - ColoModulize(module) +def post_process_colo_init_ctx(model: torch.nn.Module, + device: torch.device = torch.device('cpu'), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None): + """post_process_colo_init_ctx + + This function is called after `ColoInitContext`. + + Args: + model (torch.nn.module): the model + device (torch.device, optional): device type of the model params. Defaults to torch.device('cpu'). + dtype (torch.dtype, optional): dtype of the model params. Defaults to torch.float. + default_pg (Optional[ProcessGroup], optional): default process group. Defaults to None. Inidicates a DP-only process group. + default_dist_spec (Any, optional): default dist spec of params. Defaults to None. + + Raises: + RuntimeError: raise error if + """ + + torch_params = [] + for n, p in model.named_parameters(): + if not isinstance(p, ColoParameter): + # print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter") + torch_params.append((n, p)) + + for (n, param) in torch_params: + name_list = n.split('.') + module = model + for i in range(len(name_list) - 1): + module = module._modules[name_list[i]] + delattr(module, name_list[-1]) + setattr(module, name_list[-1], _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec)) + + del torch_params + for n, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + raise RuntimeError diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py index ed94429d4..cf05f9660 100644 --- a/colossalai/utils/model/lazy_init_context.py +++ b/colossalai/utils/model/lazy_init_context.py @@ -1,23 +1,24 @@ #!/usr/bin/env python # coding: utf-8 +import inspect +import types +from typing import Callable, List + import torch import torch.nn as nn -from colossalai.tensor import ColoParameter, ColoTensor -import types -import inspect -from typing import List, Callable +from colossalai.tensor import ColoParameter, ColoTensor from colossalai.utils.model.utils import substitute_init_recursively class LazyInitContext(): """ - A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor + A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor initialization functions for lazy initialization Note: - This API is only experimental and subject to future changes. + This API is only experimental and subject to future changes. Usage: with LazyInitContext() as ctx: @@ -30,19 +31,20 @@ class LazyInitContext(): # initialize weights ctx.lazy_init_parameters(model) - # make sure the weight is not a meta tensor + # make sure the weight is not a meta tensor # and initialized correctly assert not model.weight.is_meta and torch.all(model.weight == 0) Args: - to_meta (bool): optional, whether to initialize the model with meta tensors, default is False. - extra_torch_tensor_func (List[str]): extra torch tensor functions related + to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This + argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet. + extra_torch_tensor_func (List[str]): extra torch tensor functions related to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. """ tensor_set_value_func = ['zero_', 'fill_'] - def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None): + def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None): # TODO: hijack the torch constructor functions as well self._to_meta = to_meta self._intercepted_nn_init_func_cache = {} @@ -212,18 +214,19 @@ class LazyInitContext(): materialized_tensor = torch.empty_like(tensor, device=device) # if this tensor is a meta tensor, it must have an init function assert tensor in self._intercepted_nn_init_func_cache - tensor = materialized_tensor + else: + materialized_tensor = tensor # apply init function if tensor in self._intercepted_nn_init_func_cache: init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1] - init_func(tensor, *args, **kwargs) + init_func(materialized_tensor, *args, **kwargs) # convert it to ColoTensor or ColoParameter if is_param: - tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad) + tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad) else: - tensor = ColoTensor.from_torch_tensor(tensor) + tensor = ColoTensor.from_torch_tensor(materialized_tensor) # override the original tensor with torch.no_grad(): diff --git a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py index 4e847f17b..2b6de5fe1 100644 --- a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py +++ b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py @@ -14,7 +14,6 @@ class MultiTensorApply(object): def __init__(self, chunk_size): try: - import colossal_C MultiTensorApply.available = True self.chunk_size = chunk_size except ImportError as err: diff --git a/colossalai/utils/profiler/legacy/mem_profiler.py b/colossalai/utils/profiler/legacy/mem_profiler.py deleted file mode 100644 index f80f6ecf5..000000000 --- a/colossalai/utils/profiler/legacy/mem_profiler.py +++ /dev/null @@ -1,48 +0,0 @@ -from pathlib import Path -from typing import Union -from colossalai.engine import Engine -from torch.utils.tensorboard import SummaryWriter -from colossalai.gemini.ophooks import MemTracerOpHook -from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler - - -class MemProfiler(BaseProfiler): - """Wraper of MemOpHook, used to show GPU memory usage through each iteration - - To use this profiler, you need to pass an `engine` instance. And the usage is same like - CommProfiler. - - Usage:: - - mm_prof = MemProfiler(engine) - with ProfilerContext([mm_prof]) as prof: - writer = SummaryWriter("mem") - engine.train() - ... - prof.to_file("./log") - prof.to_tensorboard(writer) - - """ - - def __init__(self, engine: Engine, warmup: int = 50, refreshrate: int = 10) -> None: - super().__init__(profiler_name="MemoryProfiler", priority=0) - self._mem_tracer = MemTracerOpHook(warmup=warmup, refreshrate=refreshrate) - self._engine = engine - - def enable(self) -> None: - self._engine.add_hook(self._mem_tracer) - - def disable(self) -> None: - self._engine.remove_hook(self._mem_tracer) - - def to_tensorboard(self, writer: SummaryWriter) -> None: - stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats'] - for info, i in enumerate(stats): - writer.add_scalar("memory_usage/GPU", info, i) - - def to_file(self, data_file: Path) -> None: - self._mem_tracer.save_results(data_file) - - def show(self) -> None: - stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats'] - print(stats) diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 0e320f912..098ccbb45 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -2,10 +2,12 @@ from typing import Tuple import torch import torch.nn as nn + from colossalai.logging import get_dist_logger from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 -from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2 -from .zero_optimizer import ZeroOptimizer +from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2 + +from ..nn.optimizer.zero_optimizer import ZeroOptimizer def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, @@ -36,4 +38,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model return zero_model, zero_optimizer -__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer'] +__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer'] diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 7d5cfdae0..ae3a61998 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -1,31 +1,39 @@ import functools -from collections import OrderedDict -from typing import Any, Optional, Iterator, Tuple -from copy import deepcopy import itertools +from collections import OrderedDict +from copy import deepcopy +from typing import Any, Iterator, Optional, Tuple + import torch import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector from colossalai.gemini.ophooks import register_ophooks_recursively -from colossalai.zero.utils import ZeroHook from colossalai.gemini.paramhooks import BaseParamHookMgr +from colossalai.gemini.stateful_tensor import TensorState +from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory +from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device, disposable -from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector +from colossalai.utils import disposable, get_current_device from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter -from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu -from colossalai.gemini.stateful_tensor import TensorState -from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr -from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy +from colossalai.zero.utils import ZeroHook -from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, - get_gradient_predivide_factor) +from ._utils import ( + cast_float_arguments, + cast_tensor_to_fp16, + cast_tensor_to_fp32, + chunk_and_pad, + free_storage, + get_gradient_predivide_factor, +) try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX @@ -49,7 +57,7 @@ class ShardedModelV2(nn.Module): module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`. shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior. process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None. - reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group. + reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group. Generally, it should be `None`, and it's the same as `process_group`. Defaults to None. reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25. fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False. @@ -60,10 +68,10 @@ class ShardedModelV2(nn.Module): Note that 'auto' policy can only work well when no other processes use CUDA during your training. Defaults to 'cuda'. gradient_predivide_factor (Optional[float], optional): Gradient is divived by this value before reduce-scatter. Defaults to 1.0. - reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad. - Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation. - In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). - We find that PyTorch's optimizers don't support mixed precision, + reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad. + Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation. + In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). + We find that PyTorch's optimizers don't support mixed precision, so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False. """ @@ -198,15 +206,14 @@ class ShardedModelV2(nn.Module): f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n') f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n') f.write('CUDA model data (GB)\n') - f.write(str(self._memstats_collector.model_data_list('cuda', 'GB'))) f.write('\n') f.write('CUDA non model data (GB)\n') - f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB'))) + f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda'))) f.write('CPU non model data (GB)\n') - f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB'))) + f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu'))) f.write('\n') - def _pre_forward_operations(self): + def _pre_forward_operations(self, *args): # the operation will affect the memory tracer behavior in ZeroHook if self._memstats_collector: self._start_collect_memstats() @@ -223,7 +230,7 @@ class ShardedModelV2(nn.Module): p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: - self._pre_forward_operations() + self._pre_forward_operations(*args) args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs) outputs = self.module(*args, **kwargs) self._post_forward_operations() @@ -248,8 +255,8 @@ class ShardedModelV2(nn.Module): # the way to calculate margin space is based on the assumption that # model data is fixed in cuda during training. # cuda margin space can be used to store OS. - self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max( - self._memstats_collector.overall_mem_stats('cuda')) + self._cuda_margin_space = colo_device_memory_capacity( + get_current_device()) - self._memstats_collector._memstats.max_overall_cuda @torch.no_grad() def _post_backward_operations(self) -> None: diff --git a/colossalai/zero/sharded_optim/__init__.py b/colossalai/zero/sharded_optim/__init__.py index b71a70aef..30c26fb75 100644 --- a/colossalai/zero/sharded_optim/__init__.py +++ b/colossalai/zero/sharded_optim/__init__.py @@ -1,3 +1,4 @@ +from .low_level_optim import LowLevelZeroOptimizer from .sharded_optim_v2 import ShardedOptimizerV2 -__all__ = ['ShardedOptimizerV2'] +__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer'] diff --git a/colossalai/zero/sharded_optim/_utils.py b/colossalai/zero/sharded_optim/_utils.py index 49cf21969..9a839a570 100644 --- a/colossalai/zero/sharded_optim/_utils.py +++ b/colossalai/zero/sharded_optim/_utils.py @@ -1,11 +1,13 @@ import math + import torch +import torch.distributed as dist from torch._six import inf from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from colossalai.core import global_context as gpc + from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc from colossalai.utils import is_model_parallel_parameter -import torch.distributed as dist def flatten(input_): @@ -99,19 +101,24 @@ def split_half_float_double(tensor_list): return buckets -def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA): +def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA): """ Reduce the tensor in the data parallel process group :param tensor: A tensor object to reduce/all-reduce :param dtype: The data type used in communication :param dst_rank: The source rank for reduce. If dst_rank is None, + :param parallel_mode: Communication parallel mode all-reduce will be used instead of reduce. Default is None. :type tensor: torch.Tensor - :type dtype: torch.dtype + :type dtype: torch.dtype, optional :type dst_rank: int, optional + :type parallel_mode: ParallelMode, optional """ + # use the original dtype + if dtype is None: + dtype = tensor.dtype # cast the data to specified dtype for reduce/all-reduce if tensor.dtype != dtype: @@ -139,6 +146,7 @@ def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA) local_rank = gpc.get_local_rank(parallel_mode) if use_all_reduce or dst_rank == local_rank: tensor.copy_(tensor_to_reduce) + return tensor @@ -238,7 +246,7 @@ def sync_param(flat_tensor, tensor_list): Synchronize the flattened tensor and unflattened tensor list. When a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`, a new tensor is created. Thus, the flat tensor and original tensor list do not - share the same memory space. This function will update the tensor list so that + share the same memory space. This function will update the tensor list so that they point to the same value. :param flat_tensor: A flat tensor obtained by calling `torch._utils._unflatten_dense_tensors` on a tensor lsit diff --git a/colossalai/zero/sharded_optim/bookkeeping/__init__.py b/colossalai/zero/sharded_optim/bookkeeping/__init__.py new file mode 100644 index 000000000..7bcacfabf --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/__init__.py @@ -0,0 +1,6 @@ +from .bucket_store import BucketStore +from .gradient_store import GradientStore +from .parameter_store import ParameterStore +from .tensor_bucket import TensorBucket + +__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket'] diff --git a/colossalai/zero/sharded_optim/bookkeeping/base_store.py b/colossalai/zero/sharded_optim/bookkeeping/base_store.py new file mode 100644 index 000000000..d4436acaa --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/base_store.py @@ -0,0 +1,17 @@ +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class BaseStore: + + def __init__(self, dp_parallel_mode=ParallelMode.DATA): + self._world_size = gpc.get_world_size(dp_parallel_mode) + self._local_rank = gpc.get_local_rank(dp_parallel_mode) + + @property + def world_size(self): + return self._world_size + + @property + def local_rank(self): + return self._local_rank diff --git a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py new file mode 100644 index 000000000..0f2b1bb88 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py @@ -0,0 +1,44 @@ +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + +from .base_store import BaseStore + + +class BucketStore(BaseStore): + + def __init__(self, dp_parallel_mode): + super().__init__(dp_parallel_mode) + self._grads = dict() + self._params = dict() + self._num_elements_in_bucket = dict() + + self.reset() + + def num_elements_in_bucket(self, reduce_rank: int = None): + return self._num_elements_in_bucket[reduce_rank] + + def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): + self._num_elements_in_bucket[reduce_rank] += num_elements + + def add_grad(self, tensor, reduce_rank: int = None): + self._grads[reduce_rank].append(tensor) + + def add_param(self, tensor, reduce_rank: int = None): + self._params[reduce_rank].append(tensor) + + def reset(self): + keys = [None] + list(range(self._world_size)) + self._grads = {rank: [] for rank in keys} + self._params = {rank: [] for rank in keys} + self._num_elements_in_bucket = {rank: 0 for rank in keys} + + def reset_by_rank(self, reduce_rank=None): + self._grads[reduce_rank] = [] + self._params[reduce_rank] = [] + self._num_elements_in_bucket[reduce_rank] = 0 + + def get_grad(self, reduce_rank: int = None): + return self._grads[reduce_rank] + + def get_param(self, reduce_rank: int = None): + return self._params[reduce_rank] diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py new file mode 100644 index 000000000..8a9128a18 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py @@ -0,0 +1,66 @@ +from typing import List + +from torch import Tensor + +from .base_store import BaseStore + + +class GradientStore(BaseStore): + + def __init__(self, *args): + super().__init__(*args) + # bookkeeping data structures + self._averaged_gradients = dict() + + # for backward reduction hooks + self._grad_acc_objs = [] + + def add_accumulate_grad_object(self, obj): + """ + Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not + be attached successfully. + + :param obj: An object of :class:`AccumulateGrad` class + :type obj: :class:`AccumulateGrad` + """ + + self._grad_acc_objs.append(obj) + + def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]: + """ + Return average gradients of a parameter group + + :param group_id: The index of parameter group + :type group_id: int + + :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter. + :rtype: List[torch.Tensor] + """ + + return self._averaged_gradients[group_id] + + def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: + """ + Append an average gradient to the list of averaged gradients of a parameter group + + :param group_id: The index of a parameter group + :param tensor: A :class:`torch.Tensor` object + :type group_id: int + :type tensor: torch.Tensor + + """ + + if group_id in self._averaged_gradients: + self._averaged_gradients[group_id].append(tensor) + else: + self._averaged_gradients[group_id] = [tensor] + + def reset_average_gradients_by_group(self, group_id: int) -> None: + """ + Reset the bookkeeping data structure for averaged gradients to an empty list + + :param group_id: The index of a parameter group + :type group_id: int + """ + + self._averaged_gradients[group_id] = [] diff --git a/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py b/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py new file mode 100644 index 000000000..09ebaaf99 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py @@ -0,0 +1,96 @@ +from typing import List + +from torch import Tensor + +from .base_store import BaseStore + + +class ParameterStore(BaseStore): + + def __init__(self, dp_paralle_mode): + super().__init__(dp_paralle_mode) + # param partitioning data structures + self._fp16_param_to_rank = dict() + self._rank_groupid_to_fp16_param_list = dict() + self._rank_group_id_to_flat_fp16_param = dict() + + # param reduction data structures + self._is_param_reduced = dict() + self._reduced_param = [] + + def set_param_to_rank(self, tensor: Tensor, rank: int) -> None: + """ + Set the mapping between parameter to rank, each parameter should be owned by a rank. + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + :param rank: The rank of which the process is responsible for updating the parameter + :type rank: int + """ + + self._fp16_param_to_rank[tensor] = rank + + def get_param_rank(self, tensor: Tensor) -> int: + """ + Gives the rank which the parameter belongs to + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + """ + return self._fp16_param_to_rank[tensor] + + def belongs_to_current_rank(self, tensor) -> bool: + """ + Check whether a parameter is supposed to be updated by the process of the current rank + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + + :return: True if the parameter should be updated by the current rank. Otherwise false. + :rtype: bool + """ + + tensor_rank = self._fp16_param_to_rank[tensor] + return tensor_rank == self._local_rank + + def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None: + if rank not in self._rank_groupid_to_fp16_param_list: + self._rank_groupid_to_fp16_param_list[rank] = dict() + + if group_id not in self._rank_groupid_to_fp16_param_list[rank]: + self._rank_groupid_to_fp16_param_list[rank][group_id] = [] + + self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list) + + def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]: + return self._rank_groupid_to_fp16_param_list[rank][group_id] + + def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None: + if rank not in self._rank_group_id_to_flat_fp16_param: + self._rank_group_id_to_flat_fp16_param[rank] = dict() + + self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor + + def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor: + return self._rank_group_id_to_flat_fp16_param[rank][group_id] + + def is_param_reduced(self, tensor): + return self._is_param_reduced[tensor] + + def set_param_reduction_state(self, tensor, state): + self._is_param_reduced[tensor] = state + + def get_param_reduction_states(self): + return self._is_param_reduced + + def reset_previous_reduced_params(self): + self._reduced_param = [] + + def add_previous_reduced_param(self, tensor): + self._reduced_param.append(tensor) + + def clear_grads_of_previous_reduced_params(self): + if len(self._reduced_param) > 0: + for param in self._reduced_param: + param.grad = None + self.reset_previous_reduced_params() diff --git a/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py b/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py new file mode 100644 index 000000000..b32816a04 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py @@ -0,0 +1,53 @@ +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +class TensorBucket: + + def __init__(self, size): + self._max_size = size + self._current_size = 0 + self._bucket = [] + + @property + def max_size(self): + return self._max_size + + @property + def current_size(self): + return self._current_size + + def is_full_or_oversized(self): + return self._current_size >= self._max_size + + def is_empty(self): + return len(self._bucket) == 0 + + def add_to_bucket(self, tensor, allow_oversize=False): + tensor_size = tensor.numel() + + if not allow_oversize and self.will_exceed_max_size(tensor_size): + msg = f"The param bucket max size {self._max_size} is exceeded" \ + + f"by tensor (size {tensor_size})" + raise RuntimeError(msg) + + self._bucket.append(tensor) + self._current_size += tensor_size + + def will_exceed_max_size(self, tensor_size): + expected_size = self._current_size + tensor_size + return expected_size > self._max_size + + def get_bucket(self): + return self._bucket + + def empty(self): + self._bucket = [] + self._size = 0 + + def flatten(self): + return _flatten_dense_tensors(self._bucket) + + def unflatten_and_copy(self, flat_tensor): + unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket) + for old, new in zip(self._bucket, unflattened_tensor_list): + old.copy_(new) diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py new file mode 100644 index 000000000..c437ac549 --- /dev/null +++ b/colossalai/zero/sharded_optim/low_level_optim.py @@ -0,0 +1,584 @@ +from functools import partial +from itertools import groupby + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils.cuda import get_current_device + +from ._utils import ( + calculate_global_norm_from_list, + compute_norm, + flatten, + get_grad_accumulate_object, + has_inf_or_nan, + reduce_tensor, + release_param_grad, + split_half_float_double, + sync_param, +) +from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket + + +class LowLevelZeroOptimizer(ColossalaiOptimizer): + """Optimizer used for ZeRO-1 and ZeRO-2. + """ + + def __init__( + self, + optimizer: Optimizer, + + # grad scaler config + initial_scale=2**16, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=2000, + hysteresis=2, + max_scale: int = 2**24, + + # grad clipping + clip_grad_norm=0.0, + verbose=False, + + # communication + reduce_bucket_size=1024 * 1024, + communication_dtype=None, + overlap_communication=False, + + # stage 2 + partition_grad=False, + dp_parallel_mode=ParallelMode.DATA, + mp_parallel_mode=ParallelMode.MODEL, + + # cpu offload + cpu_offload=False, + + # forced dtype + forced_dtype=None): + + # TODO: add support for + # 1. fp16 master weights + # 2. contiguous gradients + # 3. cpu offload + # 4. support when some parameters requires_grad = False + super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) + self._dtype = self.optim.param_groups[0]['params'][0].dtype + self._logger = get_dist_logger() + self._verbose = verbose + + # stage 2 + self._partition_grads = partition_grad + + # cpu_offload + self._cpu_offload = cpu_offload + + # get process groups + self._dp_parallel_mode = dp_parallel_mode + self._mp_parallel_mode = mp_parallel_mode + self._local_rank = gpc.get_local_rank(dp_parallel_mode) + self._world_size = gpc.get_world_size(dp_parallel_mode) + + self._dp_group = gpc.get_group(dp_parallel_mode) + if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1: + self._mp_group = gpc.get_group(mp_parallel_mode) + else: + self._mp_group = None + + # fp16 and fp32 params for mixed precision training + self._fp16_param_groups = dict() + self._fp32_flat_param_groups_of_current_rank = dict() + + # communication params + self._overlap_communication = overlap_communication + self._reduce_bucket_size = reduce_bucket_size + self._communication_dtype = communication_dtype + + # gradient scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + verbose=verbose) + self._found_overflow = torch.FloatTensor([0]).to(get_current_device()) + + # gradient clipping + self._clip_grad_norm = clip_grad_norm + + if forced_dtype: + for group in self.optim.param_groups: + group_params = group['params'] + for param in group_params: + param.data = param.data.to(forced_dtype) + self._dtype = forced_dtype + + # check argument conflict + self._sanity_checks() + + # ParameterStore will manage the tensor buffers used for zero + # it will not manage the tensors used by mixed precision training + self._param_store = ParameterStore(self._dp_parallel_mode) + self._grad_store = GradientStore(self._dp_parallel_mode) + self._bucket_store = BucketStore(self._dp_parallel_mode) + + # iterate over the param group in the optimizer + # partition these param groups for data parallel training + # and add buffers to parameter store for future access + for group_id, param_group in enumerate(self.optim.param_groups): + group_params = param_group['params'] + + # add the fp16 params to fp16_param_groups for bookkeeping + self._fp16_param_groups[group_id] = group_params + + # assign parameters to ranks + # the params in the list are sorted + params_per_rank = self._partition_param_list(group_params) + + # store the mapping between param to rank + # each param should belong to only one rank + for rank, params in enumerate(params_per_rank): + self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params) + for param in params: + self._param_store.set_param_to_rank(param, rank) + + # move to cpu to make room to create the flat tensor + # move_tensor(params, device='cpu') + for param in group_params: + param.data = param.data.cpu() + + # flatten the reordered tensors + for rank in range(self._world_size): + tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) + with torch.no_grad(): + flat_tensor = flatten(tensor_list) + flat_tensor = flat_tensor.data.cuda() + self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor) + + # sync parameters + for rank in range(self._world_size): + flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id) + tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) + sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) + + # create a copy of fp32 weights of the parameters for which this rank is responsible + fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id) + fp32_flat_current_rank = fp16_flat_current_rank.float() + device = 'cpu' if self._cpu_offload else get_current_device() + fp32_flat_current_rank = fp32_flat_current_rank.to(device) + fp32_flat_current_rank.requires_grad = True + self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank + + # need to replace the params in the `params` field in the optimizer + # so that when the optimizer calls step(), it only updates the tensors + # managed by this data parallel rank + param_group['params'] = [fp32_flat_current_rank] + + # set reduction state + for param in self._fp16_param_groups[group_id]: + self._param_store.set_param_reduction_state(param, False) + + # intialize communication stream for + # communication-compuation overlapping + if self._overlap_communication: + self._comm_stream = torch.cuda.Stream() + + # reduction hook is only used if overlapping communication + # or stage 2 is used + # if it is stage 1 without overlapping, no hook will be attached + if self._overlap_communication or self._partition_grads: + self._attach_reduction_hook() + + @property + def dtype(self): + return self._dtype + + @property + def loss_scale(self): + return self.grad_scaler.scale + + @property + def num_param_groups(self): + return len(self._fp16_param_groups) + + def _partition_param_list(self, param_list): + params_per_rank = [[] for _ in range(self._world_size)] + numel_per_rank = [0 for _ in range(self._world_size)] + + # partititon the parameters in a greedy fashion + sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) + for param in sorted_params: + # allocate this parameter to the rank with + # the smallest numel for load balancing purpose + rank_to_go = numel_per_rank.index(min(numel_per_rank)) + params_per_rank[rank_to_go].append(param) + numel_per_rank[rank_to_go] += param.numel() + + if self._verbose: + self._logger.info(f'Number of elements on ranks: {numel_per_rank}', + ranks=[0], + parallel_mode=self._dp_parallel_mode) + return params_per_rank + + def _sanity_checks(self): + assert torch.cuda.is_available(), 'CUDA is required' + for param_group in self.optim.param_groups: + group_params = param_group['params'] + for param in group_params: + assert param.dtype == self._dtype, \ + f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + + ########################################################### + # Backward Reduction Hook + ########################################################### + + def _attach_reduction_hook(self): + # we iterate over the fp16 params + # on each param, we register a hook to its AccumulateGrad object + for group_id in range(self.num_param_groups): + param_group = self._fp16_param_groups[group_id] + for param in param_group: + if param.requires_grad: + # determines the reduction destionation rank + # this is only valid for stage 2 + # dst_rank = None means using all-reduce + # else using reduce + if self._partition_grads: + reduce_rank = self._param_store.get_param_rank(param) + else: + reduce_rank = None + + def _define_and_attach(param, reduce_rank): + # get the AccumulateGrad object of the param itself + accum_grad_obj = get_grad_accumulate_object(param) + self._grad_store.add_accumulate_grad_object(accum_grad_obj) + + reduction_func = partial(self._reduce_and_remove_grads_by_bucket, + param=param, + reduce_rank=reduce_rank) + + # define hook + # NOT IMPORTANT BUT GOOD TO KNOW: + # args here is not grad, but allow_unreacable and accumulate_grad + def reduce_grad_hook(*args): + reduction_func() + + accum_grad_obj.register_hook(reduce_grad_hook) + + _define_and_attach(param, reduce_rank) + + def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None): + param_size = param.numel() + + # check if the bucket is full + # if full, will reduce the grads already in the bucket + # after reduction, the bucket will be empty + if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: + self._reduce_grads_in_bucket(reduce_rank) + + # the param must not be reduced to ensure correctness + is_param_reduced = self._param_store.is_param_reduced(param) + if is_param_reduced: + msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ + + 'duplicate reduction will lead to arithmetic incorrectness' + raise RuntimeError(msg) + + # the param must have grad for reduction + assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' + + self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) + self._bucket_store.add_grad(param.grad, reduce_rank) + self._bucket_store.add_param(param, reduce_rank) + + def _reduce_grads_in_bucket(self, reduce_rank=None): + # reduce grads + self._reduce_grads_by_rank(reduce_rank=reduce_rank, + grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), + bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) + + # use communication stream if overlapping + # communication with computation + if self._overlap_communication: + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank) + + for param in params_in_bucket: + # the is_param_reduced flag should be False showing that + # this param is not reduced before calling self._reduce_grads_by_rank + is_param_reduced = self._param_store.is_param_reduced(param) + + if is_param_reduced: + msg = f'Parameter of size ({param.size()}) has been reduced, ' + \ + 'duplicate reduction will lead to arithmetic incorrectness' + raise RuntimeError(msg) + + # update the flag + self._param_store.set_param_reduction_state(param, True) + + # if partition grads = True + # we do not keep the gradient after reduction + if self._partition_grads and not self._param_store.belongs_to_current_rank(param): + if self._overlap_communication: + # we need to keep this gradient for now as reduction may + # be completed yet since it is using a different cuda stream + self._param_store.add_previous_reduced_param(param) + else: + param.grad = None + + self._bucket_store.reset_by_rank(reduce_rank) + + def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size): + grad_buckets_by_dtype = split_half_float_double(grads) + + for tensor_list in grad_buckets_by_dtype: + self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank) + + ############################## + # Reduction Utility Function # + ############################## + def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank): + param_bucket = TensorBucket(size=bucket_size) + + for tensor in tensor_list: + param_bucket.add_to_bucket(tensor, allow_oversize=True) + + if param_bucket.is_full_or_oversized(): + self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) + param_bucket.empty() + + if not param_bucket.is_empty(): + self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) + + def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank): + if self._overlap_communication: + torch.cuda.synchronize() + self._param_store.clear_grads_of_previous_reduced_params() + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + flat = bucket.flatten() + reduced_flat = reduce_tensor(tensor=flat, + dtype=self._communication_dtype, + dst_rank=reduce_rank, + parallel_mode=self._dp_parallel_mode) + + # update the reduced tensor + if reduce_rank is None or reduce_rank == self._local_rank: + bucket.unflatten_and_copy(reduced_flat) + + ################################ + # torch.optim.Optimizer methods + ################################ + + def backward(self, loss, retain_graph=False): + loss = self.loss_scale * loss + loss.backward(retain_graph=retain_graph) + + # finish gradient reduction + if not self._partition_grads: + self._reduce_grad_stage1() + else: + # TODO: support async comm in reduce + self._reduce_grad_stage2() + + # clear reduced grads + if self._overlap_communication: + torch.cuda.synchronize() + self._param_store.clear_grads_of_previous_reduced_params() + + def zero_grad(self, set_to_none=True): + """ + Set parameter gradients to zero. If set_to_none = True, gradient + will be set to None to save memory. + + :param set_to_none: Whether set the gradient to None. Default value is True. + :type set_to_none: bool + """ + for group_id, param_group in self._fp16_param_groups.items(): + for param in param_group: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + param.grad.detach() + param.grad.zero_() + + #################### + # Update Parameter # + #################### + + def step(self, closure=None): + assert closure is None, 'closure is not supported by step()' + + # check for overflow + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) + + # update loss scale if overflow occurs + if found_inf: + self._grad_store._averaged_gradients = dict() + self.zero_grad() + return + + # copy the grad of fp16 param to fp32 param + single_grad_partition_groups = [] + norm_groups = [] + + for group_id in range(self.num_param_groups): + # compute norm + norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id], + params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id, + rank=self._local_rank), + dp_group=self._dp_group, + mp_group=self._mp_group) + norm_groups.append(norm_group) + + # create flat gradient for the flat fp32 params + fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id) + flat_fp16_avg_grads = flatten(fp16_avg_grads) + + dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype + flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype) + + param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape + assert param_shape == flat_fp32_avg_grads.shape, \ + f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}' + + single_grad_partition_groups.append(flat_fp32_avg_grads) + device = self._fp32_flat_param_groups_of_current_rank[group_id].device + self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device) + self._grad_store._averaged_gradients[group_id] = [] + self._grad_store._averaged_gradients[group_id] = [] + + # unscale and clip grads + global_norm = calculate_global_norm_from_list(norm_list=norm_groups) + self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) + + # update the parameters + self.optim.step() + # release the fp32 grad + release_param_grad(self._fp32_flat_param_groups_of_current_rank.values()) + + # update fp16 partition updated by the current rank + for group_id in range(len(self._fp16_param_groups)): + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id) + fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] + fp16_param.data.copy_(fp32_param) + + # broadcast the updated model weights + handles = [] + for group_id in range(self.num_param_groups): + for rank in range(self._world_size): + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) + handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True) + handles.append(handle) + + for handle in handles: + handle.wait() + + ################## + # FP16 Utilities # + ################## + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(0.0) + + # check for overflow + for group_id in range(len(self._fp16_param_groups)): + for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + self._found_overflow.fill_(1.0) + break + + # all-reduce across dp group + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group) + + # all-reduce over model parallel group + if self._mp_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group) + + if self._found_overflow.item() > 0: + return True + else: + return False + + def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): + # compute combined scale factor for this group + combined_scale = self.loss_scale + + if self._clip_grad_norm > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm + if clip > 1: + combined_scale = clip * self.loss_scale + + for grad in grad_groups_flat: + grad.data.mul_(1. / combined_scale) + + ############################ + # Gradient Synchronization # + ############################ + + def sync_grad(self): + # update param already reduced flag + reduction_states = self._param_store.get_param_reduction_states() + for tensor, state in reduction_states.items(): + reduction_states[tensor] = False + + # accumulate gradient + avg_gradients = self._grad_store._averaged_gradients + for group_id in range(self.num_param_groups): + param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id) + + if group_id not in avg_gradients: + avg_gradients[group_id] = [] + + param_idx = 0 + for param in param_group: + if param.grad is not None: + if len(avg_gradients[group_id]) == param_idx: + avg_gradients[group_id].append(param.grad) + else: + avg_gradients[group_id][param_idx].add_(param.grad) + param_idx += 1 + + # the gradients needed are stored in the avg_gradients buffer + # thus, can clear this + self.zero_grad() + + def _reduce_grad_stage1(self): + # if not overlapping communication (no reduction hook is attached) + # we need to manually reduce these gradients + if not self._overlap_communication: + for group_id in range(len(self._fp16_param_groups)): + param_group = self._fp16_param_groups[group_id] + for param in param_group: + if param.grad is not None: + self._reduce_and_remove_grads_by_bucket(param) + + # we need to reduce the gradients + # left in the communication bucket + self._reduce_grads_in_bucket() + + def _reduce_grad_stage2(self): + # when partition_grads is True, reduction hooks + # are attached in the __init__ function, so we + # only need to reduce the gradients + # left in the communication bucket + for reduce_rank in range(self._world_size): + self._reduce_grads_in_bucket(reduce_rank) diff --git a/colossalai/zero/utils/zero_hook_v2.py b/colossalai/zero/utils/gemini_hook.py similarity index 91% rename from colossalai/zero/utils/zero_hook_v2.py rename to colossalai/zero/utils/gemini_hook.py index 584a0fe37..35569c717 100644 --- a/colossalai/zero/utils/zero_hook_v2.py +++ b/colossalai/zero/utils/gemini_hook.py @@ -1,11 +1,13 @@ -import torch -from colossalai.tensor.param_op_hook import ParamOpHook -from colossalai.gemini import TensorState -from enum import Enum -from typing import List from contextlib import contextmanager +from enum import Enum from functools import partial +from typing import List + +import torch + +from colossalai.gemini import TensorState from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.tensor.param_op_hook import ColoParamOpHook class TrainingPhase(Enum): @@ -13,7 +15,7 @@ class TrainingPhase(Enum): BACKWARD = 1 -class ZeROHookV2(ParamOpHook): +class GeminiZeROHook(ColoParamOpHook): def __init__(self, gemini_manager: GeminiManager) -> None: super().__init__() @@ -30,7 +32,9 @@ class ZeROHookV2(ParamOpHook): self._gemini_manager.adjust_layout(chunks) for chunk in chunks: self._chunk_manager.access_chunk(chunk) - self._gemini_manager.sample_model_data() + + # record cuda model data of the current OP + self._gemini_manager.record_model_data_volume() def post_op(self, params): params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)] diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/utils/zero_hook.py index 189d1ad2d..87bf2c0f5 100644 --- a/colossalai/zero/utils/zero_hook.py +++ b/colossalai/zero/utils/zero_hook.py @@ -2,23 +2,22 @@ from typing import Optional import torch import torch.distributed as dist + +from colossalai.gemini.memory_tracer import MemStatsCollector +from colossalai.gemini.ophooks import BaseOpHook +from colossalai.gemini.stateful_tensor import TensorState +from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.logging import get_dist_logger from colossalai.registry import OPHOOKS - from colossalai.utils import get_current_device - from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.gemini.ophooks import BaseOpHook - -from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr -from colossalai.gemini.memory_tracer import MemStatsCollector -from colossalai.gemini.stateful_tensor import TensorState @OPHOOKS.register_module class ZeroHook(BaseOpHook): """ A hook to process sharded param for ZeRO method. + Warning: this class has been deprecated after version 0.1.12 """ def __init__(self, @@ -68,7 +67,7 @@ class ZeroHook(BaseOpHook): # record model data statistics if self._memstarts_collector: - self._memstarts_collector.sample_model_data() + self._memstarts_collector.record_model_data_volume() def pre_fwd_exec(self, module: torch.nn.Module, *args): self.adjust_module_data(module) diff --git a/docker/Dockerfile b/docker/Dockerfile index 4b55dc1eb..0faba17b9 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,17 +1,18 @@ FROM hpcaitech/cuda-conda:11.3 # install torch -RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch +RUN conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch # install apex RUN git clone https://github.com/NVIDIA/apex && \ cd apex && \ - pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ + pip install packaging && \ + pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./ # install colossalai RUN git clone https://github.com/hpcaitech/ColossalAI.git \ - && cd ./ColossalAI \ - && pip install -v --no-cache-dir . + && cd ./ColossalAI \ + && CUDA_EXT=1 pip install -v --no-cache-dir . # install titans RUN pip install --no-cache-dir titans diff --git a/docs/colossalai/colossalai.pipeline.layer_sepc.rst b/docs/colossalai/colossalai.pipeline.layer_sepc.rst index 0ff6a83c2..156660b5c 100644 --- a/docs/colossalai/colossalai.pipeline.layer_sepc.rst +++ b/docs/colossalai/colossalai.pipeline.layer_sepc.rst @@ -1,5 +1,5 @@ colossalai.pipeline.layer\_sepc =============================== -.. automodule:: colossalai.pipeline.layer_sepc +.. automodule:: colossalai.pipeline.layer_spec :members: diff --git a/docs/colossalai/colossalai.pipeline.rst b/docs/colossalai/colossalai.pipeline.rst index adaebea2d..6f7652d49 100644 --- a/docs/colossalai/colossalai.pipeline.rst +++ b/docs/colossalai/colossalai.pipeline.rst @@ -8,6 +8,6 @@ colossalai.pipeline .. toctree:: :maxdepth: 2 - colossalai.pipeline.layer_sepc + colossalai.pipeline.layer_spec colossalai.pipeline.pipelinable colossalai.pipeline.utils diff --git a/docs/colossalai/colossalai.zero.utils.rst b/docs/colossalai/colossalai.zero.utils.rst index 15cf4d70d..50ee9071e 100644 --- a/docs/colossalai/colossalai.zero.utils.rst +++ b/docs/colossalai/colossalai.zero.utils.rst @@ -9,4 +9,4 @@ colossalai.zero.utils :maxdepth: 2 colossalai.zero.utils.zero_hook - colossalai.zero.utils.zero_hook_v2 + colossalai.zero.utils.gemini_hook diff --git a/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst b/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst index 6c9af62f1..e6d6673af 100644 --- a/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst +++ b/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst @@ -1,5 +1,5 @@ colossalai.zero.utils.zero\_hook\_v2 ==================================== -.. automodule:: colossalai.zero.utils.zero_hook_v2 +.. automodule:: colossalai.zero.utils.gemini_hook :members: diff --git a/docs/requirements.txt b/docs/requirements.txt index ae216364c..2b3b1a25b 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,5 @@ tensorboard -deepspeed apex sphinx sphinx-rtd-theme -myst-parser \ No newline at end of file +myst-parser diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 000000000..53ab0896d --- /dev/null +++ b/examples/README.md @@ -0,0 +1,28 @@ +## Examples folder document + +## Table of Contents + + +## Example folder description + +This folder provides several examples using colossalai. The images folder includes model like diffusion, dreambooth and vit. The language folder includes gpt, opt, palm and roberta. The tutorial folder is for concept illustration, such as auto-parallel, hybrid-parallel and so on. + + +## Integrate Your Example With System Testing + +For example code contributor, to meet the expectation and test your code automatically using github workflow function, here are several steps: + + +- (must) Have a test_ci.sh file in the folder like shown below in 'File Structure Chart' +- The dataset should be located in the company's machine and can be announced using environment variable and thus no need for a separate terminal command. +- The model parameters should be small to allow fast testing. +- File Structure Chart + + └─examples + └─images + └─vit + └─requirements.txt + └─test_ci.sh diff --git a/examples/images/diffusion/LICENSE b/examples/images/diffusion/LICENSE new file mode 100644 index 000000000..0e609df0d --- /dev/null +++ b/examples/images/diffusion/LICENSE @@ -0,0 +1,82 @@ +Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors + +CreativeML Open RAIL-M +dated August 22, 2022 + +Section I: PREAMBLE + +Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and Licensor agree as follows: + +1. Definitions + +- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +- "Output" means the results of operating a Model as embodied in informational content resulting therefrom. +- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. +- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. +- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. +- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. +- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." +- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +You must cause any modified files to carry prominent notices stating that You changed the files; +You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). +6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. +8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. +9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. +10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. +11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. + +END OF TERMS AND CONDITIONS + + + + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: +- In any way that violates any applicable national, federal, state, local or international law or regulation; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate personal identifiable information that can be used to harm an individual; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; +- To provide medical advice and medical results interpretation; +- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md new file mode 100644 index 000000000..abb1d24c0 --- /dev/null +++ b/examples/images/diffusion/README.md @@ -0,0 +1,248 @@ +# ColoDiffusion: Stable Diffusion with Colossal-AI + + +Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion). + +

+ +

+ +- [Training](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce Stable Diffusion memory consumption by up to 5.6x and hardware cost by up to 46x (from A100 to RTX3060). + +

+ +

+ + +- [DreamBooth Fine-tuning](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/dreambooth): Personalize your model using just 3-5 images of the desired subject. + +

+ +

+ + +- [Inference](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion): Reduce inference GPU memory consumption by 2.5x. + + +More details can be found in our [blog of Stable Diffusion v1](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) and [blog of Stable Diffusion v2](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0). + +## Installation + +### Option #1: install from source +#### Step 1: Requirements + +A suitable [conda](https://conda.io/) environment named `ldm` can be created +and activated with: + +``` +conda env create -f environment.yaml +conda activate ldm +``` + +You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running + +``` +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch +pip install transformers==4.19.2 diffusers invisible-watermark +pip install -e . +``` + +##### Step 2: install lightning + +Install Lightning version later than 2022.01.04. We suggest you install lightning from source. + +https://github.com/Lightning-AI/lightning.git + + +##### Step 3:Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website + +For example, you can install v0.1.12 from our official website. + +``` +pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Option #2: Use Docker + +To use the stable diffusion Docker image, you can either build using the provided the [Dockerfile](./docker/Dockerfile) or pull a Docker image from our Docker hub. + +``` +# 1. build from dockerfile +cd docker +docker build -t hpcaitech/diffusion:0.2.0 . + +# 2. pull from our docker hub +docker pull hpcaitech/diffusion:0.2.0 +``` + +Once you have the image ready, you can launch the image with the following command: + +```bash +######################## +# On Your Host Machine # +######################## +# make sure you start your image in the repository root directory +cd Colossal-AI + +# run the docker container +docker run --rm \ + -it --gpus all \ + -v $PWD:/workspace \ + -v :/data/scratch \ + -v :/root/.cache/huggingface \ + hpcaitech/diffusion:0.2.0 \ + /bin/bash + +######################## +# Insider Container # +######################## +# Once you have entered the docker container, go to the stable diffusion directory for training +cd examples/images/diffusion/ + +# start training with colossalai +bash train_colossalai.sh +``` + +It is important for you to configure your volume mapping in order to get the best training experience. +1. **Mandatory**, mount your prepared data to `/data/scratch` via `-v :/data/scratch`, where you need to replace `` with the actual data path on your machine. +2. **Recommended**, store the downloaded model weights to your host machine instead of the container directory via `-v :/root/.cache/huggingface`, where you need to repliace the `` with the actual path. In this way, you don't have to repeatedly download the pretrained weights for every `docker run`. +3. **Optional**, if you encounter any problem stating that shared memory is insufficient inside container, please add `-v /dev/shm:/dev/shm` to your `docker run` command. + + + +## Download the model checkpoint from pretrained + +### stable-diffusion-v1-4 + +Our default model config use the weight from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style) + +``` +git lfs install +git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 +``` + +### stable-diffusion-v1-5 from runway + +If you want to useed the Last [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) wiegh from runwayml + +``` +git lfs install +git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 +``` + +## Dataset + +The dataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/), +you should the change the `data.file_path` in the `config/train_colossalai.yaml` + +## Training + +We provide the script `train_colossalai.sh` to run the training task with colossalai, +and can also use `train_ddp.sh` to run the training task with ddp to compare. + +In `train_colossalai.sh` the main command is: +``` +python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml +``` + +- you can change the `--logdir` to decide where to save the log information and the last checkpoint. + +### Training config + +You can change the trainging config in the yaml file + +- devices: device number used for training, default 8 +- max_epochs: max training epochs, default 2 +- precision: the precision type used in training, default 16 (fp16), you must use fp16 if you want to apply colossalai +- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai) + +## Finetune Example +### Training on Teyvat Datasets + +We provide the finetuning example on [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is create by BLIP generated captions. + +You can run by config `configs/Teyvat/train_colossalai_teyvat.yaml` +``` +python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml +``` + +## Inference +you can get yout training last.ckpt and train config.yaml in your `--logdir`, and run by +``` +python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms + --outdir ./output \ + --config path/to/logdir/checkpoints/last.ckpt \ + --ckpt /path/to/logdir/configs/project.yaml \ +``` + +```commandline +usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA] + [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT] + [--seed SEED] [--precision {full,autocast}] + +optional arguments: + -h, --help show this help message and exit + --prompt [PROMPT] the prompt to render + --outdir [OUTDIR] dir to write results to + --skip_grid do not save a grid, only individual samples. Helpful when evaluating lots of samples + --skip_save do not save individual samples. For speed measurements. + --ddim_steps DDIM_STEPS + number of ddim sampling steps + --plms use plms sampling + --laion400m uses the LAION400M model + --fixed_code if enabled, uses the same starting code across samples + --ddim_eta DDIM_ETA ddim eta (eta=0.0 corresponds to deterministic sampling + --n_iter N_ITER sample this often + --H H image height, in pixel space + --W W image width, in pixel space + --C C latent channels + --f F downsampling factor + --n_samples N_SAMPLES + how many samples to produce for each given prompt. A.k.a. batch size + --n_rows N_ROWS rows in the grid (default: n_samples) + --scale SCALE unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) + --from-file FROM_FILE + if specified, load prompts from this file + --config CONFIG path to config which constructs model + --ckpt CKPT path to checkpoint of model + --seed SEED the seed (for reproducible sampling) + --use_int8 whether to use quantization method + --precision {full,autocast} + evaluate at this precision +``` + +## Comments + +- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion) +, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch), +[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion). +Thanks for open-sourcing! + +- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories). + +- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch). + +## BibTeX + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +@misc{rombach2021highresolution, + title={High-Resolution Image Synthesis with Latent Diffusion Models}, + author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer}, + year={2021}, + eprint={2112.10752}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +@article{dao2022flashattention, + title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, + author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + journal={arXiv preprint arXiv:2205.14135}, + year={2022} +} +``` diff --git a/examples/images/diffusion/configs/Inference/v2-inference-v.yaml b/examples/images/diffusion/configs/Inference/v2-inference-v.yaml new file mode 100644 index 000000000..8ec8dfbfe --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inference-v.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/v2-inference.yaml b/examples/images/diffusion/configs/Inference/v2-inference.yaml new file mode 100644 index 000000000..152c4f3c2 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inference.yaml @@ -0,0 +1,67 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml b/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml new file mode 100644 index 000000000..32a9471d7 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml @@ -0,0 +1,158 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml b/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml new file mode 100644 index 000000000..531199de4 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml @@ -0,0 +1,72 @@ +model: + base_learning_rate: 5.0e-07 + target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + depth_stage_config: + target: ldm.modules.midas.api.MiDaSInference + params: + model_type: "dpt_hybrid" + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 5 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/x4-upscaling.yaml b/examples/images/diffusion/configs/Inference/x4-upscaling.yaml new file mode 100644 index 000000000..45ecbf9ad --- /dev/null +++ b/examples/images/diffusion/configs/Inference/x4-upscaling.yaml @@ -0,0 +1,75 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion + params: + parameterization: "v" + low_scale_key: "lr" + linear_start: 0.0001 + linear_end: 0.02 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 128 + channels: 4 + cond_stage_trainable: false + conditioning_key: "hybrid-adm" + monitor: val/loss_simple_ema + scale_factor: 0.08333 + use_ema: False + + low_scale_config: + target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation + params: + noise_schedule_config: # image space + linear_start: 0.0001 + linear_end: 0.02 + max_noise_level: 350 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) + image_size: 128 + in_channels: 7 + out_channels: 4 + model_channels: 256 + attention_resolutions: [ 2,4,8] + num_res_blocks: 2 + channel_mult: [ 1, 2, 2, 4] + disable_self_attentions: [True, True, True, False] + disable_middle_self_attn: False + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + use_linear_in_transformer: True + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Teyvat/README.md b/examples/images/diffusion/configs/Teyvat/README.md new file mode 100644 index 000000000..65ba3fb80 --- /dev/null +++ b/examples/images/diffusion/configs/Teyvat/README.md @@ -0,0 +1,8 @@ +# Dataset Card for Teyvat BLIP captions +Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion). + +BLIP generated captions for characters images from [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters)and [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2). + +For each row the dataset contains `image` and `text` keys. `image` is a varying size PIL png, and `text` is the accompanying text caption. Only a train split is provided. + +The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Model type`, and `Description`, the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP). diff --git a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml new file mode 100644 index 000000000..d466c1c56 --- /dev/null +++ b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml @@ -0,0 +1,126 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 16 + num_workers: 4 + train: + target: ldm.data.teyvat.hf_dataset + params: + path: Fazzie/Teyvat + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip + +lightning: + trainer: + accelerator: 'gpu' + devices: 2 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.ColossalAIStrategy + params: + use_chunk: True + enable_distributed_storage: True + placement_policy: cuda + force_outputs_fp32: true + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_colossalai.yaml b/examples/images/diffusion/configs/train_colossalai.yaml new file mode 100644 index 000000000..0354311f8 --- /dev/null +++ b/examples/images/diffusion/configs/train_colossalai.yaml @@ -0,0 +1,123 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 128 + wrap: False + # num_workwers should be 2 * batch_size, and total num less than 1024 + # e.g. if use 8 devices, no more than 128 + num_workers: 128 + train: + target: ldm.data.base.Txt2ImgIterableBaseDataset + params: + file_path: # YOUR DATASET_PATH + world_size: 1 + rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 8 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.ColossalAIStrategy + params: + use_chunk: True + enable_distributed_storage: True + placement_policy: cuda + force_outputs_fp32: true + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml new file mode 100644 index 000000000..0273ca862 --- /dev/null +++ b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml @@ -0,0 +1,127 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + num_workers: 4 + train: + target: ldm.data.cifar10.hf_dataset + params: + name: cifar10 + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip + +lightning: + trainer: + accelerator: 'gpu' + devices: 1 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.ColossalAIStrategy + params: + use_chunk: True + enable_distributed_storage: True + placement_policy: cuda + force_outputs_fp32: true + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml new file mode 100644 index 000000000..a63df887e --- /dev/null +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -0,0 +1,119 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 128 + # num_workwers should be 2 * batch_size, and the total num less than 1024 + # e.g. if use 8 devices, no more than 128 + num_workers: 128 + train: + target: ldm.data.base.Txt2ImgIterableBaseDataset + params: + file_path: # YOUR DATAPATH + world_size: 1 + rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 8 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.DDPStrategy + params: + find_unused_parameters: False + log_every_n_steps: 2 +# max_steps: 6o + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/data2/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/configs/train_pokemon.yaml b/examples/images/diffusion/configs/train_pokemon.yaml new file mode 100644 index 000000000..aadb5f2a0 --- /dev/null +++ b/examples/images/diffusion/configs/train_pokemon.yaml @@ -0,0 +1,120 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 32 + wrap: False + train: + target: ldm.data.pokemon.PokemonDataset + # params: + # file_path: "/data/scratch/diffuser/laion_part0/" + # world_size: 1 + # rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 1 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: strategies.ColossalAIStrategy + params: + use_chunk: True + enable_distributed_storage: True + placement_policy: cuda + force_outputs_fp32: true + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname diff --git a/examples/images/diffusion/docker/Dockerfile b/examples/images/diffusion/docker/Dockerfile new file mode 100644 index 000000000..17cc8bc8b --- /dev/null +++ b/examples/images/diffusion/docker/Dockerfile @@ -0,0 +1,41 @@ +FROM hpcaitech/pytorch-cuda:1.12.0-11.3.0 + +# install torch +# RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch +RUN apt-get update +RUN apt-get install ffmpeg libsm6 libxext6 -y + +# install apex +RUN git clone https://github.com/NVIDIA/apex && \ + cd apex && \ + pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./ + +# install colossalai +# RUN git clone https://github.com/hpcaitech/ColossalAI.git \ +# && cd ./ColossalAI \ +# && pip install -v --no-cache-dir . + +RUN pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org + + +# install our lightning, it will be merged to Lightning official repo. +RUN git clone https://github.com/1SAA/lightning.git && \ + cd lightning && \ + git checkout strategy/colossalai && \ + export PACKAGE_NAME=pytorch && \ + pip install --no-cache-dir . + +# install titans +RUN pip install --no-cache-dir titans + +RUN git clone https://github.com/hpcaitech/ColossalAI.git && \ + cd ./ColossalAI/examples/images/diffusion && \ + pip install -r requirements.txt && \ + pip install --no-cache-dir transformers==4.19.2 diffusers invisible-watermark + +# install tensornvme +# RUN conda install cmake && \ +# git clone https://github.com/hpcaitech/TensorNVMe.git && \ +# cd TensorNVMe && \ +# pip install -r requirements.txt && \ +# pip install -v --no-cache-dir . diff --git a/examples/images/diffusion/environment.yaml b/examples/images/diffusion/environment.yaml new file mode 100644 index 000000000..69904c72e --- /dev/null +++ b/examples/images/diffusion/environment.yaml @@ -0,0 +1,31 @@ +name: ldm +channels: + - pytorch + - defaults +dependencies: + - python=3.9.12 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.12.1 + - torchvision=0.13.1 + - numpy=1.23.1 + - pip: + - albumentations==1.3.0 + - opencv-python==4.6.0.66 + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit==1.12.1 + - einops==0.3.0 + - transformers==4.19.2 + - webdataset==0.2.5 + - kornia==0.6 + - open_clip_torch==2.0.2 + - invisible-watermark>=0.1.5 + - streamlit-drawable-canvas==0.8.0 + - torchmetrics==0.7.0 + - prefetch_generator + - datasets + - colossalai + - -e . diff --git a/examples/images/diffusion/ldm/.DS_Store b/examples/images/diffusion/ldm/.DS_Store new file mode 100644 index 000000000..647199f9f Binary files /dev/null and b/examples/images/diffusion/ldm/.DS_Store differ diff --git a/examples/images/diffusion/ldm/data/__init__.py b/examples/images/diffusion/ldm/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/images/diffusion/ldm/data/base.py b/examples/images/diffusion/ldm/data/base.py new file mode 100644 index 000000000..a12492c95 --- /dev/null +++ b/examples/images/diffusion/ldm/data/base.py @@ -0,0 +1,76 @@ +import math +import os +from abc import abstractmethod + +import cv2 +import numpy as np +import torch +from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset + + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + + def __init__(self, file_path: str, rank, world_size): + super().__init__() + self.file_path = file_path + self.folder_list = [] + self.file_list = [] + self.txt_list = [] + self.info = self._get_file_info(file_path) + self.start = self.info['start'] + self.end = self.info['end'] + self.rank = rank + + self.world_size = world_size + # self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size))) + # self.iter_start = self.start + self.rank * self.per_worker + # self.iter_end = min(self.iter_start + self.per_worker, self.end) + # self.num_records = self.iter_end - self.iter_start + # self.valid_ids = [i for i in range(self.iter_end)] + self.num_records = self.end - self.start + self.valid_ids = [i for i in range(self.end)] + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + # return self.iter_end - self.iter_start + return self.end - self.start + + def __iter__(self): + sample_iterator = self._sample_generator(self.start, self.end) + # sample_iterator = self._sample_generator(self.iter_start, self.iter_end) + return sample_iterator + + def _sample_generator(self, start, end): + for idx in range(start, end): + file_name = self.file_list[idx] + txt_name = self.txt_list[idx] + f_ = open(txt_name, 'r') + txt_ = f_.read() + f_.close() + image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = torch.from_numpy(image) / 255 + yield {"txt": txt_, "image": image} + + def _get_file_info(self, file_path): + info = \ + { + "start": 1, + "end": 0, + } + self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i] + for folder in self.folder_list: + files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i] + txts = [k.replace('jpg', 'txt') for k in files] + self.file_list.extend(files) + self.txt_list.extend(txts) + info['end'] = len(self.file_list) + # with open(file_path, 'r') as fin: + # for _ in enumerate(fin): + # info['end'] += 1 + # self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list] + return info diff --git a/examples/images/diffusion/ldm/data/cifar10.py b/examples/images/diffusion/ldm/data/cifar10.py new file mode 100644 index 000000000..53cd61263 --- /dev/null +++ b/examples/images/diffusion/ldm/data/cifar10.py @@ -0,0 +1,184 @@ +from typing import Dict +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't suport captions yet + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename) + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +def hf_dataset( + name, + image_transforms=[], + image_column="img", + label_column="label", + text_column="txt", + split='train', + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(name, split=split) + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + tform = transforms.Compose(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + + label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + + processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]] + + return processed + + ds.set_transform(pre_process) + return ds + +class TextOnly(Dataset): + def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): + """Returns only captions with dummy images""" + self.output_size = output_size + self.image_key = image_key + self.caption_key = caption_key + if isinstance(captions, Path): + self.captions = self._load_caption_file(captions) + else: + self.captions = captions + + if n_gpus > 1: + # hack to make sure that all the captions appear on each gpu + repeated = [n_gpus*[x] for x in self.captions] + self.captions = [] + [self.captions.extend(x) for x in repeated] + + def __len__(self): + return len(self.captions) + + def __getitem__(self, index): + dummy_im = torch.zeros(3, self.output_size, self.output_size) + dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') + return {self.image_key: dummy_im, self.caption_key: self.captions[index]} + + def _load_caption_file(self, filename): + with open(filename, 'rt') as f: + captions = f.readlines() + return [x.strip('\n') for x in captions] \ No newline at end of file diff --git a/examples/images/diffusion/ldm/data/imagenet.py b/examples/images/diffusion/ldm/data/imagenet.py new file mode 100644 index 000000000..1c473f9c6 --- /dev/null +++ b/examples/images/diffusion/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/examples/images/diffusion/ldm/data/lsun.py b/examples/images/diffusion/ldm/data/lsun.py new file mode 100644 index 000000000..6256e4571 --- /dev/null +++ b/examples/images/diffusion/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/examples/images/diffusion/ldm/data/teyvat.py b/examples/images/diffusion/ldm/data/teyvat.py new file mode 100644 index 000000000..61dc29d56 --- /dev/null +++ b/examples/images/diffusion/ldm/data/teyvat.py @@ -0,0 +1,152 @@ +from typing import Dict +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't suport captions yet + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename) + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +def hf_dataset( + path = "Fazzie/Teyvat", + image_transforms=[], + image_column="image", + text_column="text", + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(path, name="train") + ds = ds["train"] + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))] + ) + tform = transforms.Compose(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + processed[caption_key] = examples[text_column] + + return processed + + ds.set_transform(pre_process) + return ds \ No newline at end of file diff --git a/examples/images/diffusion/ldm/lr_scheduler.py b/examples/images/diffusion/ldm/lr_scheduler.py new file mode 100644 index 000000000..be39da9ca --- /dev/null +++ b/examples/images/diffusion/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py new file mode 100644 index 000000000..b1bd83778 --- /dev/null +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -0,0 +1,223 @@ +import torch +try: + import lightning.pytorch as pl +except: + import pytorch_lightning as pl + +import torch.nn.functional as F +from contextlib import contextmanager + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config +from ldm.modules.ema import LitEma + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0. < ema_decay < 1. + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( + self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + diff --git a/examples/images/diffusion/ldm/models/diffusion/__init__.py b/examples/images/diffusion/ldm/models/diffusion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/images/diffusion/ldm/models/diffusion/classifier.py b/examples/images/diffusion/ldm/models/diffusion/classifier.py new file mode 100644 index 000000000..612a8371b --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import lightning.pytorch as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/examples/images/diffusion/ldm/models/diffusion/ddim.py b/examples/images/diffusion/ldm/models/diffusion/ddim.py new file mode 100644 index 000000000..27ead0ea9 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/ddim.py @@ -0,0 +1,336 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + ucg_schedule=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py new file mode 100644 index 000000000..f7ac0a735 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1895 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +try: + import lightning.pytorch as pl + from lightning.pytorch.utilities import rank_zero_only, rank_zero_info +except: + import pytorch_lightning as pl + from pytorch_lightning.utilities import rank_zero_only, rank_zero_info +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager, nullcontext +from functools import partial +import itertools +from tqdm import tqdm +from torchvision.utils import make_grid + +from omegaconf import ListConfig + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL + + +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.diffusionmodules.openaimodel import * + +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d +from ldm.modules.encoders.modules import * + +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import * +from ldm.models.diffusion.ddim import * +from ldm.modules.diffusionmodules.openaimodel import * +from ldm.modules.diffusionmodules.model import * + + +from ldm.modules.diffusionmodules.model import Model, Encoder, Decoder + +from ldm.util import instantiate_from_config + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + use_fp16 = True, + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, + ): + super().__init__() + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + rank_zero_info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + + self.unet_config = unet_config + self.conditioning_key = conditioning_key + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + self.ckpt_path = ckpt_path + self.ignore_keys = ignore_keys + self.load_only_unet = load_only_unet + self.reset_ema = reset_ema + self.reset_num_ema_updates = reset_num_ema_updates + + if reset_ema: assert exists(ckpt_path) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + if reset_ema: + assert self.use_ema + print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + self.timesteps = timesteps + self.beta_schedule = beta_schedule + self.given_betas = given_betas + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.logvar_init = logvar_init + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + self.use_fp16 = use_fp16 + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + if self.make_it_fit: + n_params = len([name for name, _ in + itertools.chain(self.named_parameters(), + self.named_buffers())]) + for name, param in tqdm( + itertools.chain(self.named_parameters(), + self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys:\n {missing}") + if len(unexpected) > 0: + print(f"\nUnexpected Keys:\n {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + if self.use_fp16: + x = x.to(memory_format=torch.contiguous_format).half() + else: + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + use_fp16=True, + force_null_conditioning=False, + *args, **kwargs): + self.force_null_conditioning = force_null_conditioning + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = None + + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.first_stage_config = first_stage_config + self.cond_stage_config = cond_stage_config + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, self.ignore_keys) + self.restarted_from_ckpt = True + if self.reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + def configure_sharded_model(self) -> None: + rank_zero_info("Configure sharded model for LatentDiffusion") + self.model = DiffusionWrapper(self.unet_config, self.conditioning_key) + if self.use_ema: + self.model_ema = LitEma(self.model) + + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, ignore_keys=self.ignore_keys, only_model=self.load_only_unet) + if self.reset_ema: + assert self.use_ema + print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + self.register_schedule(given_betas=self.given_betas, beta_schedule=self.beta_schedule, timesteps=self.timesteps, + linear_start=self.linear_start, linear_end=self.linear_end, cosine_s=self.cosine_s) + + self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + self.instantiate_first_stage(self.first_stage_config) + self.instantiate_cond_stage(self.cond_stage_config) + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, self.ignore_keys) + self.restarted_from_ckpt = True + if self.reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z.half() if self.use_fp16 else self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, return_x=False): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None and not self.force_null_conditioning: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox', "txt"]: + xc = batch[cond_key] + elif cond_key in ['class_label', 'cls']: + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_x: + out.extend([x]) + if return_original_cond: + out.append(xc) + + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is expected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None, **kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, + shape, cond, verbose=False, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + if self.cond_stage_key in ["class_label", "cls"]: + xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device) + return self.get_learned_conditioning(xc) + else: + raise NotImplementedError("todo") + if isinstance(c, list): # in case the encoder gives us a list + for i in range(len(c)): + c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + else: + c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + return c + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', "cls"]: + try: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + except KeyError: + # probably no "human_label" in batch + pass + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) + if self.model.conditioning_key == "crossattn-adm": + uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1. - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + + from colossalai.nn.optimizer import HybridAdam + opt = HybridAdam(params, lr=lr) + + # opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + if not self.sequential_cross_attn: + cc = torch.cat(c_crossattn, 1) + else: + cc = c_crossattn + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'crossattn-adm': + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + self.noise_level_key = noise_level_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + if self.use_fp16: + x_low = x_low.to(memory_format=torch.contiguous_format).half() + else: + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + if self.noise_level_key is not None: + # get noise level from batch instead, e.g. when extracting a custom noise level for bsr + raise NotImplementedError('TODO') + + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + if log_mode: + # TODO: maybe disable if too expensive + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, + log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + uc[k] = c[k] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + return log + + +class LatentFinetuneDiffusion(LatentDiffusion): + """ + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys: tuple, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight" + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, **kwargs + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), 'did not find matching parameter to modify' + new_entry[:, :self.keep_dims, ...] = sd[k] + sd[k] = new_entry + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + return log + + +class LatentInpaintDiffusion(LatentFinetuneDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, **kwargs + ): + super().__init__(concat_keys, *args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + if self.use_fp16: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).half() + else: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) + log["masked_image"] = rearrange(args[0]["masked_image"], + 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + return log + + +class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): + """ + condition on monocular depth estimation + """ + + def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.depth_model = instantiate_from_config(depth_stage_config) + self.depth_stage_key = concat_keys[0] + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + c_cat = list() + for ck in self.concat_keys: + cc = batch[ck] + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + cc = self.depth_model(cc) + cc = torch.nn.functional.interpolate( + cc, + size=z.shape[2:], + mode="bicubic", + align_corners=False, + ) + + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], + keepdim=True) + cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + depth = self.depth_model(args[0][self.depth_stage_key]) + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ + torch.amax(depth, dim=[1, 2, 3], keepdim=True) + log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + return log + + +class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): + """ + condition on low-res image (and optionally on some spatial noise augmentation) + """ + def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, + low_scale_config=None, low_scale_key=None, *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.reshuffle_patch_size = reshuffle_patch_size + self.low_scale_model = None + if low_scale_config is not None: + print("Initializing a low-scale model") + assert exists(low_scale_key) + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + # optionally make spatial noise_level here + c_cat = list() + noise_level = None + for ck in self.concat_keys: + cc = batch[ck] + cc = rearrange(cc, 'b h w c -> b c h w') + if exists(self.reshuffle_patch_size): + assert isinstance(self.reshuffle_patch_size, int) + cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', + p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + if exists(self.low_scale_model) and ck == self.low_scale_key: + cc, noise_level = self.low_scale_model(cc) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + if exists(noise_level): + all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} + else: + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + return log diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 000000000..7427f38c0 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 000000000..095e5ba3c --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1154 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( + model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + ===================================================== + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + ===================================================== + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in tqdm(range(1, order), desc="DPM init order"): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, + solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1), desc="DPM multistep"): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, + solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), + N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 000000000..7d137b8cf --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,87 @@ +"""SAMPLING ONLY.""" +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +MODEL_TYPES = { + "eps": "noise", + "v": "v" +} + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=MODEL_TYPES[self.model.parameterization], + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/plms.py b/examples/images/diffusion/ldm/models/diffusion/plms.py new file mode 100644 index 000000000..7002a365d --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/plms.py @@ -0,0 +1,244 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.models.diffusion.sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py new file mode 100644 index 000000000..7eff02be6 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,22 @@ +import torch +import numpy as np + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/attention.py b/examples/images/diffusion/ldm/modules/attention.py new file mode 100644 index 000000000..d504d939f --- /dev/null +++ b/examples/images/diffusion/ldm/modules/attention.py @@ -0,0 +1,331 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from ldm.modules.diffusionmodules.util import checkpoint + + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/__init__.py b/examples/images/diffusion/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py new file mode 100644 index 000000000..57b9a4b80 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,857 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from typing import Optional, Any + +try: + from lightning.pytorch.utilities import rank_zero_info +except: + from pytorch_lightning.utilities import rank_zero_info + +from ldm.modules.attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + rank_zero_info(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + rank_zero_info("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 000000000..cd639d936 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,787 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + t_emb = t_emb.type(self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py new file mode 100644 index 000000000..038166620 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial + +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from ldm.util import default + + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py new file mode 100644 index 000000000..36b4a171b --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,273 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +import math +import os + +import numpy as np +import torch +import torch.nn as nn +from einops import repeat +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2) + + elif schedule == "cosine": + timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + from torch.utils.checkpoint import checkpoint as torch_checkpoint + return torch_checkpoint(func, *inputs) + # args = tuple(inputs) + tuple(params) + # return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled() + } + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return nn.GroupNorm(16, channels) + # return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/examples/images/diffusion/ldm/modules/distributions/__init__.py b/examples/images/diffusion/ldm/modules/distributions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/images/diffusion/ldm/modules/distributions/distributions.py b/examples/images/diffusion/ldm/modules/distributions/distributions.py new file mode 100644 index 000000000..f2b8ef901 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/examples/images/diffusion/ldm/modules/ema.py b/examples/images/diffusion/ldm/modules/ema.py new file mode 100644 index 000000000..bded25019 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/ema.py @@ -0,0 +1,80 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/examples/images/diffusion/ldm/modules/encoders/__init__.py b/examples/images/diffusion/ldm/modules/encoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/images/diffusion/ldm/modules/encoders/modules.py b/examples/images/diffusion/ldm/modules/encoders/modules.py new file mode 100644 index 000000000..4edd5496b --- /dev/null +++ b/examples/images/diffusion/ldm/modules/encoders/modules.py @@ -0,0 +1,213 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import open_clip +from ldm.util import default, count_params + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + diff --git a/examples/images/diffusion/ldm/modules/image_degradation/__init__.py b/examples/images/diffusion/ldm/modules/image_degradation/__init__.py new file mode 100644 index 000000000..7836cada8 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 000000000..32ef56169 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 000000000..808c7f882 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,651 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + if up: + image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils/test.png b/examples/images/diffusion/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 000000000..4249b43de Binary files /dev/null and b/examples/images/diffusion/ldm/modules/image_degradation/utils/test.png differ diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 000000000..0175f155a --- /dev/null +++ b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/midas/__init__.py b/examples/images/diffusion/ldm/modules/midas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/images/diffusion/ldm/modules/midas/api.py b/examples/images/diffusion/ldm/modules/midas/api.py new file mode 100644 index 000000000..b58ebbffd --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/api.py @@ -0,0 +1,170 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from ldm.modules.midas.midas.dpt_depth import DPTDepthModel +from ldm.modules.midas.midas.midas_net import MidasNet +from ldm.modules.midas.midas.midas_net_custom import MidasNet_small +from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet + + +ISL_PATHS = { + "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array + # NOTE: we expect that the correct transform has been called during dataloading. + with torch.no_grad(): + prediction = self.model(x) + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=x.shape[2:], + mode="bicubic", + align_corners=False, + ) + assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) + return prediction + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/__init__.py b/examples/images/diffusion/ldm/modules/midas/midas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py new file mode 100644 index 000000000..5cf430239 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py new file mode 100644 index 000000000..2145d18fa --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py new file mode 100644 index 000000000..4e9aab5d2 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py new file mode 100644 index 000000000..8a9549778 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py new file mode 100644 index 000000000..50e4acb5e --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py new file mode 100644 index 000000000..350cbc116 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/examples/images/diffusion/ldm/modules/midas/midas/vit.py b/examples/images/diffusion/ldm/modules/midas/midas/vit.py new file mode 100644 index 000000000..ea46b1be8 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py new file mode 100644 index 000000000..9a9d3b5b6 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/examples/images/diffusion/ldm/util.py b/examples/images/diffusion/ldm/util.py new file mode 100644 index 000000000..8c09ca1c7 --- /dev/null +++ b/examples/images/diffusion/ldm/util.py @@ -0,0 +1,197 @@ +import importlib + +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py new file mode 100644 index 000000000..87d495123 --- /dev/null +++ b/examples/images/diffusion/main.py @@ -0,0 +1,826 @@ +import argparse +import csv +import datetime +import glob +import importlib +import os +import sys +import time + +import numpy as np +import torch +import torchvision + +try: + import lightning.pytorch as pl +except: + import pytorch_lightning as pl + +from functools import partial + +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from prefetch_generator import BackgroundGenerator +from torch.utils.data import DataLoader, Dataset, Subset, random_split + +try: + from lightning.pytorch import seed_everything + from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint + from lightning.pytorch.trainer import Trainer + from lightning.pytorch.utilities import rank_zero_info, rank_zero_only + LIGHTNING_PACK_NAME = "lightning.pytorch." +except: + from pytorch_lightning import seed_everything + from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint + from pytorch_lightning.trainer import Trainer + from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + LIGHTNING_PACK_NAME = "pytorch_lightning." + +from ldm.data.base import Txt2ImgIterableBaseDataset +from ldm.util import instantiate_from_config + +# from ldm.modules.attention import enable_flash_attentions + + +class DataLoaderX(DataLoader): + + def __iter__(self): + return BackgroundGenerator(super().__iter__()) + + +def get_parser(**parser_kwargs): + + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument( + "-n", + "--name", + type=str, + const=True, + default="", + nargs="?", + help="postfix for logdir", + ) + parser.add_argument( + "-r", + "--resume", + type=str, + const=True, + default="", + nargs="?", + help="resume from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-t", + "--train", + type=str2bool, + const=True, + default=False, + nargs="?", + help="train", + ) + parser.add_argument( + "--no-test", + type=str2bool, + const=True, + default=False, + nargs="?", + help="disable test", + ) + parser.add_argument("-p", "--project", help="name of new or path to existing project") + parser.add_argument( + "-d", + "--debug", + type=str2bool, + nargs="?", + const=True, + default=False, + help="enable post-mortem debugging", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=23, + help="seed for seed_everything", + ) + parser.add_argument( + "-f", + "--postfix", + type=str, + default="", + help="post-postfix for default name", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + default="logs", + help="directory for logging dat shit", + ) + parser.add_argument( + "--scale_lr", + type=str2bool, + nargs="?", + const=True, + default=True, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + parser.add_argument( + "--use_fp16", + type=str2bool, + nargs="?", + const=True, + default=True, + help="whether to use fp16", + ) + parser.add_argument( + "--flash", + type=str2bool, + const=True, + default=False, + nargs="?", + help="whether to use flash attention", + ) + return parser + + +def nondefault_trainer_args(opt): + parser = argparse.ArgumentParser() + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args([]) + return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + + +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, Txt2ImgIterableBaseDataset): + split_size = dataset.num_records // worker_info.num_workers + # reset num_records to the true number to retain reliable length information + dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + current_id = np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + else: + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +class DataModuleFromConfig(pl.LightningDataModule): + + def __init__(self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=False): + super().__init__() + self.batch_size = batch_size + self.dataset_configs = dict() + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.use_worker_init_fn = use_worker_init_fn + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + self.wrap = wrap + + def prepare_data(self): + for data_cfg in self.dataset_configs.values(): + instantiate_from_config(data_cfg) + + def setup(self, stage=None): + self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) + if self.wrap: + for k in self.datasets: + self.datasets[k] = WrappedDataset(self.datasets[k]) + + def _train_dataloader(self): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["train"], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn) + + def _val_dataloader(self, shuffle=False): + if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _test_dataloader(self, shuffle=False): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + + # do not shuffle dataloader for iterable dataset + shuffle = shuffle and (not is_iterable_dataset) + + return DataLoaderX(self.datasets["test"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _predict_dataloader(self, shuffle=False): + if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["predict"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn) + + +class SetupCallback(Callback): + + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + self.config = config + self.lightning_config = lightning_config + + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + # def on_pretrain_routine_start(self, trainer, pl_module): + def on_fit_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + if "callbacks" in self.lightning_config: + if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: + os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + print("Project config") + print(OmegaConf.to_yaml(self.config)) + OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + + print("Lightning config") + print(OmegaConf.to_yaml(self.lightning_config)) + OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass + + +class ImageLogger(Callback): + + def __init__(self, + batch_frequency, + max_images, + clamp=True, + increase_log_steps=True, + rescale=True, + disabled=False, + log_on_batch_idx=False, + log_first_step=False, + log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = { + pl.loggers.CSVLogger: self._testtube, + } + self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + @rank_zero_only + def _testtube(self, pl_module, images, batch_idx, split): + for k in images: + grid = torchvision.utils.make_grid(images[k]) + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + + tag = f"{split}/{k}" + pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step) + + @rank_zero_only + def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "images", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) + + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1., 1.) + + self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, + batch_idx) + + logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) + logger_log_images(pl_module, images, pl_module.global_step, split) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if ((check_idx % self.batch_freq) == 0 or + (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step): + try: + self.log_steps.pop(0) + except IndexError as e: + print(e) + pass + return True + return False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + # self.log_img(pl_module, batch, batch_idx, split="train") + pass + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if hasattr(pl_module, 'calibrate_grad_norm'): + if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + + def on_train_start(self, trainer, pl_module): + rank_zero_info("Training is starting") + + def on_train_end(self, trainer, pl_module): + rank_zero_info("Training is ending") + + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) + torch.cuda.synchronize(trainer.strategy.root_device.index) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + torch.cuda.synchronize(trainer.strategy.root_device.index) + max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20 + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.strategy.reduce(max_memory) + epoch_time = trainer.strategy.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +if __name__ == "__main__": + # custom parser to specify config files, train, test and debug mode, + # postfix, resume. + # `--key value` arguments are interpreted as arguments to the trainer. + # `nested.key=value` arguments are interpreted as config parameters. + # configs are merged from left-to-right followed by command line parameters. + + # model: + # base_learning_rate: float + # target: path to lightning module + # params: + # key: value + # data: + # target: main.DataModuleFromConfig + # params: + # batch_size: int + # wrap: bool + # train: + # target: path to train dataset + # params: + # key: value + # validation: + # target: path to validation dataset + # params: + # key: value + # test: + # target: path to test dataset + # params: + # key: value + # lightning: (optional, has sane defaults and can be specified on cmdline) + # trainer: + # additional arguments to trainer + # logger: + # logger to instantiate + # modelcheckpoint: + # modelcheckpoint to instantiate + # callbacks: + # callback1: + # target: importpath + # params: + # key: value + + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + # add cwd for convenience and to make classes in this file available when + # running as `python main.py` + # (in particular `main.DataModuleFromConfig`) + sys.path.append(os.getcwd()) + + parser = get_parser() + parser = Trainer.add_argparse_args(parser) + + opt, unknown = parser.parse_known_args() + if opt.name and opt.resume: + raise ValueError("-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint") + if opt.resume: + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + paths = opt.resume.split("/") + # idx = len(paths)-paths[::-1].index("logs")+1 + # logdir = "/".join(paths[:idx]) + logdir = "/".join(paths[:-2]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), opt.resume + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") + + opt.resume_from_checkpoint = ckpt + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) + opt.base = base_configs + opt.base + _tmp = logdir.split("/") + nowname = _tmp[-1] + else: + if opt.name: + name = "_" + opt.name + elif opt.base: + cfg_fname = os.path.split(opt.base[0])[-1] + cfg_name = os.path.splitext(cfg_fname)[0] + name = "_" + cfg_name + else: + name = "" + nowname = now + name + opt.postfix + logdir = os.path.join(opt.logdir, nowname) + + ckptdir = os.path.join(logdir, "checkpoints") + cfgdir = os.path.join(logdir, "configs") + seed_everything(opt.seed) + + try: + # init and save configs + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + lightning_config = config.pop("lightning", OmegaConf.create()) + # merge trainer cli with config + trainer_config = lightning_config.get("trainer", OmegaConf.create()) + + for k in nondefault_trainer_args(opt): + trainer_config[k] = getattr(opt, k) + + print(trainer_config) + if not trainer_config["accelerator"] == "gpu": + del trainer_config["accelerator"] + cpu = True + print("Running on CPU") + else: + cpu = False + print("Running on GPU") + trainer_opt = argparse.Namespace(**trainer_config) + lightning_config.trainer = trainer_config + + # model + use_fp16 = trainer_config.get("precision", 32) == 16 + if use_fp16: + config.model["params"].update({"use_fp16": True}) + print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) + else: + config.model["params"].update({"use_fp16": False}) + print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) + + model = instantiate_from_config(config.model) + # trainer and callbacks + trainer_kwargs = dict() + + # config the logger + # default logger configs + default_logger_cfgs = { + "wandb": { + "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger", + "params": { + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, + } + }, + "tensorboard": { + "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger", + "params": { + "save_dir": logdir, + "name": "diff_tb", + "log_graph": True + } + } + } + + default_logger_cfg = default_logger_cfgs["tensorboard"] + if "logger" in lightning_config: + logger_cfg = lightning_config.logger + else: + logger_cfg = default_logger_cfg + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) + + # config the strategy, defualt is ddp + if "strategy" in trainer_config: + strategy_cfg = trainer_config["strategy"] + print("Using strategy: {}".format(strategy_cfg["target"])) + strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"] + else: + strategy_cfg = { + "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy", + "params": { + "find_unused_parameters": False + } + } + print("Using strategy: DDPStrategy") + + trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg) + + # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to + # specify which metric is used to determine best models + default_modelckpt_cfg = { + "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint", + "params": { + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + } + } + if hasattr(model, "monitor"): + print(f"Monitoring {model.monitor} as checkpoint metric.") + default_modelckpt_cfg["params"]["monitor"] = model.monitor + default_modelckpt_cfg["params"]["save_top_k"] = 3 + + if "modelcheckpoint" in lightning_config: + modelckpt_cfg = lightning_config.modelcheckpoint + else: + modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) + print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") + if version.parse(pl.__version__) < version.parse('1.4.0'): + trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) + + # add callback which sets up log directory + default_callbacks_cfg = { + "setup_callback": { + "target": "main.SetupCallback", + "params": { + "resume": opt.resume, + "now": now, + "logdir": logdir, + "ckptdir": ckptdir, + "cfgdir": cfgdir, + "config": config, + "lightning_config": lightning_config, + } + }, + "image_logger": { + "target": "main.ImageLogger", + "params": { + "batch_frequency": 750, + "max_images": 4, + "clamp": True + } + }, + "learning_rate_logger": { + "target": "main.LearningRateMonitor", + "params": { + "logging_interval": "step", + # "log_momentum": True + } + }, + "cuda_callback": { + "target": "main.CUDACallback" + }, + } + if version.parse(pl.__version__) >= version.parse('1.4.0'): + default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) + + if "callbacks" in lightning_config: + callbacks_cfg = lightning_config.callbacks + else: + callbacks_cfg = OmegaConf.create() + + if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: + print( + 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') + default_metrics_over_trainsteps_ckpt_dict = { + 'metrics_over_trainsteps_checkpoint': { + "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint', + 'params': { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + } + } + default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) + + callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) + if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): + callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint + elif 'ignore_keys_callback' in callbacks_cfg: + del callbacks_cfg['ignore_keys_callback'] + + trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + + trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) + trainer.logdir = logdir ### + + # data + data = instantiate_from_config(config.data) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. + # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + print("#### Data #####") + for k in data.datasets: + print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + + # configure learning rate + bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate + if not cpu: + ngpu = trainer_config["devices"] + else: + ngpu = 1 + if 'accumulate_grad_batches' in lightning_config.trainer: + accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches + else: + accumulate_grad_batches = 1 + print(f"accumulate_grad_batches = {accumulate_grad_batches}") + lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches + if opt.scale_lr: + model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr + print( + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" + .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + else: + model.learning_rate = base_lr + print("++++ NOT USING LR SCALING ++++") + print(f"Setting learning rate to {model.learning_rate:.2e}") + + # allow checkpointing via USR1 + def melk(*args, **kwargs): + # run all checkpoint hooks + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def divein(*args, **kwargs): + if trainer.global_rank == 0: + import pudb + pudb.set_trace() + + import signal + + signal.signal(signal.SIGUSR1, melk) + signal.signal(signal.SIGUSR2, divein) + + # run + if opt.train: + try: + trainer.fit(model, data) + except Exception: + melk() + raise + # if not opt.no_test and not trainer.interrupted: + # trainer.test(model, data) + except Exception: + if opt.debug and trainer.global_rank == 0: + try: + import pudb as debugger + except ImportError: + import pdb as debugger + debugger.post_mortem() + raise + finally: + # move newly created debug project to debug_runs + if opt.debug and not opt.resume and trainer.global_rank == 0: + dst, name = os.path.split(logdir) + dst = os.path.join(dst, "debug_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + os.rename(logdir, dst) + if trainer.global_rank == 0: + print(trainer.profiler.summary()) diff --git a/examples/images/diffusion/requirements.txt b/examples/images/diffusion/requirements.txt new file mode 100644 index 000000000..60c4b903e --- /dev/null +++ b/examples/images/diffusion/requirements.txt @@ -0,0 +1,18 @@ +albumentations==1.3.0 +opencv-python==4.6.0 +pudb==2019.2 +prefetch_generator +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +torchmetrics==0.6 +omegaconf==2.1.1 +test-tube>=0.7.5 +streamlit>=0.73.1 +einops==0.3.0 +transformers==4.19.2 +webdataset==0.2.5 +open-clip-torch==2.7.0 +gradio==3.11 +datasets +colossalai +-e . diff --git a/examples/images/diffusion/scripts/download_first_stages.sh b/examples/images/diffusion/scripts/download_first_stages.sh new file mode 100755 index 000000000..a8d79e99c --- /dev/null +++ b/examples/images/diffusion/scripts/download_first_stages.sh @@ -0,0 +1,41 @@ +#!/bin/bash +wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip +wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip +wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip +wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip +wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip +wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip +wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip +wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip +wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip + + + +cd models/first_stage_models/kl-f4 +unzip -o model.zip + +cd ../kl-f8 +unzip -o model.zip + +cd ../kl-f16 +unzip -o model.zip + +cd ../kl-f32 +unzip -o model.zip + +cd ../vq-f4 +unzip -o model.zip + +cd ../vq-f4-noattn +unzip -o model.zip + +cd ../vq-f8 +unzip -o model.zip + +cd ../vq-f8-n256 +unzip -o model.zip + +cd ../vq-f16 +unzip -o model.zip + +cd ../.. \ No newline at end of file diff --git a/examples/images/diffusion/scripts/download_models.sh b/examples/images/diffusion/scripts/download_models.sh new file mode 100755 index 000000000..84297d7b8 --- /dev/null +++ b/examples/images/diffusion/scripts/download_models.sh @@ -0,0 +1,49 @@ +#!/bin/bash +wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip +wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip +wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip +wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip +wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip +wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip +wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip +wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip +wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip +wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip +wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip + + + +cd models/ldm/celeba256 +unzip -o celeba-256.zip + +cd ../ffhq256 +unzip -o ffhq-256.zip + +cd ../lsun_churches256 +unzip -o lsun_churches-256.zip + +cd ../lsun_beds256 +unzip -o lsun_beds-256.zip + +cd ../text2img256 +unzip -o model.zip + +cd ../cin256 +unzip -o model.zip + +cd ../semantic_synthesis512 +unzip -o model.zip + +cd ../semantic_synthesis256 +unzip -o model.zip + +cd ../bsr_sr +unzip -o model.zip + +cd ../layout2img-openimages256 +unzip -o model.zip + +cd ../inpainting_big +unzip -o model.zip + +cd ../.. diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py new file mode 100644 index 000000000..877538d47 --- /dev/null +++ b/examples/images/diffusion/scripts/img2img.py @@ -0,0 +1,296 @@ +"""make variations of input image""" + +import argparse, os +import PIL +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +from torch import autocast +from contextlib import nullcontext +try: + from lightning.pytorch import seed_everything +except: + from pytorch_lightning import seed_everything +from imwatermark import WatermarkEncoder + + +from scripts.txt2img import put_watermark +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from utils import replace_module, getModelSize + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.eval() + return model + + +def load_img(path): + image = Image.open(path).convert("RGB") + w, h = image.size + print(f"loaded input image of size ({w}, {h}) from {path}") + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2. * image - 1. + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--init-img", + type=str, + nargs="?", + help="path to the input image" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/img2img-samples" + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across all samples ", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor, most often 8 or 16", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=2, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=9.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--strength", + type=float, + default=0.8, + help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v2-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + parser.add_argument( + "--use_int8", + type=bool, + default=False, + help="use int8 for inference", + ) + + opt = parser.parse_args() + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + # quantize model + if opt.use_int8: + model = replace_module(model) + # # to compute the model size + # getModelSize(model) + + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "SDV2" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + assert os.path.isfile(opt.init_img) + init_image = load_img(opt.init_img).to(device) + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + + sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) + + assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(opt.strength * opt.ddim_steps) + print(f"target t_enc is {t_enc} steps") + + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + + # encode (scaled latent) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device)) + # decode it + samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, ) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") + + +if __name__ == "__main__": + main() + # # to compute the mem allocated + # print(torch.cuda.max_memory_allocated() / 1024 / 1024) diff --git a/examples/images/diffusion/scripts/inpaint.py b/examples/images/diffusion/scripts/inpaint.py new file mode 100644 index 000000000..d6e6387a9 --- /dev/null +++ b/examples/images/diffusion/scripts/inpaint.py @@ -0,0 +1,98 @@ +import argparse, os, sys, glob +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +from main import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler + + +def make_batch(image, mask, device): + image = np.array(Image.open(image).convert("RGB")) + image = image.astype(np.float32)/255.0 + image = image[None].transpose(0,3,1,2) + image = torch.from_numpy(image) + + mask = np.array(Image.open(mask).convert("L")) + mask = mask.astype(np.float32)/255.0 + mask = mask[None,None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = (1-mask)*image + + batch = {"image": image, "mask": mask, "masked_image": masked_image} + for k in batch: + batch[k] = batch[k].to(device=device) + batch[k] = batch[k]*2.0-1.0 + return batch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--indir", + type=str, + nargs="?", + help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + opt = parser.parse_args() + + masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) + images = [x.replace("_mask.png", ".png") for x in masks] + print(f"Found {len(masks)} inputs.") + + config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") + model = instantiate_from_config(config.model) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], + strict=False) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + with torch.no_grad(): + with model.ema_scope(): + for image, mask in tqdm(zip(images, masks)): + outpath = os.path.join(opt.outdir, os.path.split(image)[1]) + batch = make_batch(image, mask, device=device) + + # encode masked image and concat downsampled mask + c = model.cond_stage_model.encode(batch["masked_image"]) + cc = torch.nn.functional.interpolate(batch["mask"], + size=c.shape[-2:]) + c = torch.cat((c, cc), dim=1) + + shape = (c.shape[1]-1,)+c.shape[2:] + samples_ddim, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False) + x_samples_ddim = model.decode_first_stage(samples_ddim) + + image = torch.clamp((batch["image"]+1.0)/2.0, + min=0.0, max=1.0) + mask = torch.clamp((batch["mask"]+1.0)/2.0, + min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, + min=0.0, max=1.0) + + inpainted = (1-mask)*image+mask*predicted_image + inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/examples/images/diffusion/scripts/knn2img.py b/examples/images/diffusion/scripts/knn2img.py new file mode 100644 index 000000000..e6eaaecab --- /dev/null +++ b/examples/images/diffusion/scripts/knn2img.py @@ -0,0 +1,398 @@ +import argparse, os, sys, glob +import clip +import torch +import torch.nn as nn +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +import scann +import time +from multiprocessing import cpu_count + +from ldm.util import instantiate_from_config, parallel_data_prefetch +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder + +DATABASES = [ + "openimages", + "artbench-art_nouveau", + "artbench-baroque", + "artbench-expressionism", + "artbench-impressionism", + "artbench-post_impressionism", + "artbench-realism", + "artbench-romanticism", + "artbench-renaissance", + "artbench-surrealism", + "artbench-ukiyo_e", +] + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +class Searcher(object): + def __init__(self, database, retriever_version='ViT-L/14'): + assert database in DATABASES + # self.database = self.load_database(database) + self.database_name = database + self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' + self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' + self.retriever = self.load_retriever(version=retriever_version) + self.database = {'embedding': [], + 'img_id': [], + 'patch_coords': []} + self.load_database() + self.load_searcher() + + def train_searcher(self, k, + metric='dot_product', + searcher_savedir=None): + + print('Start training searcher') + searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / + np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], + k, metric) + self.searcher = searcher.score_brute_force().build() + print('Finish training searcher') + + if searcher_savedir is not None: + print(f'Save trained searcher under "{searcher_savedir}"') + os.makedirs(searcher_savedir, exist_ok=True) + self.searcher.serialize(searcher_savedir) + + def load_single_file(self, saved_embeddings): + compressed = np.load(saved_embeddings) + self.database = {key: compressed[key] for key in compressed.files} + print('Finished loading of clip embeddings.') + + def load_multi_files(self, data_archive): + out_data = {key: [] for key in self.database} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + out_data[key].append(d[key]) + + return out_data + + def load_database(self): + + print(f'Load saved patch embedding from "{self.database_path}"') + file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + + if len(file_content) == 1: + self.load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(self.load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in + self.database} + else: + raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') + + def load_retriever(self, version='ViT-L/14', ): + model = FrozenClipImageEmbedder(model=version) + if torch.cuda.is_available(): + model.cuda() + model.eval() + return model + + def load_searcher(self): + print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) + print('Finished loading searcher.') + + def search(self, x, k): + if self.searcher is None and self.database['embedding'].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if len(x.shape) == 3: + x = x[:, 0] + query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis] + + start = time.time() + nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) + end = time.time() + + out_embeddings = self.database['embedding'][nns] + out_img_ids = self.database['img_id'][nns] + out_pc = self.database['patch_coords'][nns] + + out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + 'img_ids': out_img_ids, + 'patch_coords': out_pc, + 'queries': x, + 'exec_time': end - start, + 'nns': nns, + 'q_embeddings': query_embeddings} + + return out + + def __call__(self, x, n): + return self.search(x, n) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc) + # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt? + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--n_repeat", + type=int, + default=1, + help="number of repeats in CLIP latent space", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--H", + type=int, + default=768, + help="image height, in pixel space", + ) + + parser.add_argument( + "--W", + type=int, + default=768, + help="image width, in pixel space", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + + parser.add_argument( + "--config", + type=str, + default="configs/retrieval-augmented-diffusion/768x768.yaml", + help="path to config which constructs model", + ) + + parser.add_argument( + "--ckpt", + type=str, + default="models/rdm/rdm768x768/model.ckpt", + help="path to checkpoint of model", + ) + + parser.add_argument( + "--clip_type", + type=str, + default="ViT-L/14", + help="which CLIP model to use for retrieval and NN encoding", + ) + parser.add_argument( + "--database", + type=str, + default='artbench-surrealism', + choices=DATABASES, + help="The database used for the search, only applied when --use_neighbors=True", + ) + parser.add_argument( + "--use_neighbors", + default=False, + action='store_true', + help="Include neighbors in addition to text prompt for conditioning", + ) + parser.add_argument( + "--knn", + default=10, + type=int, + help="The number of included neighbors, only applied when --use_neighbors=True", + ) + + opt = parser.parse_args() + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + print(f"sampling scale for cfg is {opt.scale:.2f}") + + searcher = None + if opt.use_neighbors: + searcher = Searcher(opt.database) + + with torch.no_grad(): + with model.ema_scope(): + for n in trange(opt.n_iter, desc="Sampling"): + all_samples = list() + for prompts in tqdm(data, desc="data"): + print("sampling prompts:", prompts) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = clip_text_encoder.encode(prompts) + uc = None + if searcher is not None: + nn_dict = searcher(c, opt.knn) + c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + if opt.scale != 1.0: + uc = torch.zeros_like(c) + if isinstance(prompts, tuple): + prompts = list(prompts) + shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples_ddim) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/images/diffusion/scripts/sample_diffusion.py b/examples/images/diffusion/scripts/sample_diffusion.py new file mode 100644 index 000000000..876fe3c36 --- /dev/null +++ b/examples/images/diffusion/scripts/sample_diffusion.py @@ -0,0 +1,313 @@ +import argparse, os, sys, glob, datetime, yaml +import torch +import time +import numpy as np +from tqdm import trange + +from omegaconf import OmegaConf +from PIL import Image + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config + +rescale = lambda x: (x + 1.) / 2. + +def custom_to_pil(x): + x = x.detach().cpu() + x = torch.clamp(x, -1., 1.) + x = (x + 1.) / 2. + x = x.permute(1, 2, 0).numpy() + x = (255 * x).astype(np.uint8) + x = Image.fromarray(x) + if not x.mode == "RGB": + x = x.convert("RGB") + return x + + +def custom_to_np(x): + # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py + sample = x.detach().cpu() + sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) + sample = sample.permute(0, 2, 3, 1) + sample = sample.contiguous() + return sample + + +def logs2pil(logs, keys=["sample"]): + imgs = dict() + for k in logs: + try: + if len(logs[k].shape) == 4: + img = custom_to_pil(logs[k][0, ...]) + elif len(logs[k].shape) == 3: + img = custom_to_pil(logs[k]) + else: + print(f"Unknown format for key {k}. ") + img = None + except: + img = None + imgs[k] = img + return imgs + + +@torch.no_grad() +def convsample(model, shape, return_intermediates=True, + verbose=True, + make_prog_row=False): + + + if not make_prog_row: + return model.p_sample_loop(None, shape, + return_intermediates=return_intermediates, verbose=verbose) + else: + return model.progressive_denoising( + None, shape, verbose=True + ) + + +@torch.no_grad() +def convsample_ddim(model, steps, shape, eta=1.0 + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): + + + log = dict() + + shape = [batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size] + + with model.ema_scope("Plotting"): + t0 = time.time() + if vanilla: + sample, progrow = convsample(model, shape, + make_prog_row=True) + else: + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, + eta=eta) + + t1 = time.time() + + x_sample = model.decode_first_stage(sample) + + log["sample"] = x_sample + log["time"] = t1 - t0 + log['throughput'] = sample.shape[0] / (t1 - t0) + print(f'Throughput for this batch: {log["throughput"]}') + return log + +def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): + if vanilla: + print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + else: + print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') + + + tstart = time.time() + n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + # path = logdir + if model.cond_stage_model is None: + all_images = [] + + print(f"Running unconditional sampling for {n_samples} samples") + for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): + logs = make_convolutional_sample(model, batch_size=batch_size, + vanilla=vanilla, custom_steps=custom_steps, + eta=eta) + n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") + all_images.extend([custom_to_np(logs["sample"])]) + if n_saved >= n_samples: + print(f'Finish after generating {n_saved} samples') + break + all_img = np.concatenate(all_images, axis=0) + all_img = all_img[:n_samples] + shape_str = "x".join([str(x) for x in all_img.shape]) + nppath = os.path.join(nplog, f"{shape_str}-samples.npz") + np.savez(nppath, all_img) + + else: + raise NotImplementedError('Currently only sampling for unconditional models supported.') + + print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") + + +def save_logs(logs, path, n_saved=0, key="sample", np_path=None): + for k in logs: + if k == key: + batch = logs[key] + if np_path is None: + for x in batch: + img = custom_to_pil(x) + imgpath = os.path.join(path, f"{key}_{n_saved:06}.png") + img.save(imgpath) + n_saved += 1 + else: + npbatch = custom_to_np(batch) + shape_str = "x".join([str(x) for x in npbatch.shape]) + nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz") + np.savez(nppath, npbatch) + n_saved += npbatch.shape[0] + return n_saved + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--resume", + type=str, + nargs="?", + help="load from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-n", + "--n_samples", + type=int, + nargs="?", + help="number of samples to draw", + default=50000 + ) + parser.add_argument( + "-e", + "--eta", + type=float, + nargs="?", + help="eta for ddim sampling (0.0 yields deterministic sampling)", + default=1.0 + ) + parser.add_argument( + "-v", + "--vanilla_sample", + default=False, + action='store_true', + help="vanilla sampling (default option is DDIM sampling)?", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + nargs="?", + help="extra logdir", + default="none" + ) + parser.add_argument( + "-c", + "--custom_steps", + type=int, + nargs="?", + help="number of steps for ddim and fastdpm sampling", + default=50 + ) + parser.add_argument( + "--batch_size", + type=int, + nargs="?", + help="the bs", + default=10 + ) + return parser + + +def load_model_from_config(config, sd): + model = instantiate_from_config(config) + model.load_state_dict(sd,strict=False) + model.cuda() + model.eval() + return model + + +def load_model(config, ckpt, gpu, eval_mode): + if ckpt: + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + else: + pl_sd = {"state_dict": None} + global_step = None + model = load_model_from_config(config.model, + pl_sd["state_dict"]) + + return model, global_step + + +if __name__ == "__main__": + now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + sys.path.append(os.getcwd()) + command = " ".join(sys.argv) + + parser = get_parser() + opt, unknown = parser.parse_known_args() + ckpt = None + + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + # paths = opt.resume.split("/") + try: + logdir = '/'.join(opt.resume.split('/')[:-1]) + # idx = len(paths)-paths[::-1].index("logs")+1 + print(f'Logdir is {logdir}') + except ValueError: + paths = opt.resume.split("/") + idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt + logdir = "/".join(paths[:idx]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory" + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "model.ckpt") + + base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml"))) + opt.base = base_configs + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + gpu = True + eval_mode = True + + if opt.logdir != "none": + locallog = logdir.split(os.sep)[-1] + if locallog == "": locallog = logdir.split(os.sep)[-2] + print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") + logdir = os.path.join(opt.logdir, locallog) + + print(config) + + model, global_step = load_model(config, ckpt, gpu, eval_mode) + print(f"global step: {global_step}") + print(75 * "=") + print("logging to:") + logdir = os.path.join(logdir, "samples", f"{global_step:08}", now) + imglogdir = os.path.join(logdir, "img") + numpylogdir = os.path.join(logdir, "numpy") + + os.makedirs(imglogdir) + os.makedirs(numpylogdir) + print(logdir) + print(75 * "=") + + # write config out + sampling_file = os.path.join(logdir, "sampling_config.yaml") + sampling_conf = vars(opt) + + with open(sampling_file, 'w') as f: + yaml.dump(sampling_conf, f, default_flow_style=False) + print(sampling_conf) + + + run(model, imglogdir, eta=opt.eta, + vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, + batch_size=opt.batch_size, nplog=numpylogdir) + + print("done.") diff --git a/examples/images/diffusion/scripts/tests/test_checkpoint.py b/examples/images/diffusion/scripts/tests/test_checkpoint.py new file mode 100644 index 000000000..a32e66d44 --- /dev/null +++ b/examples/images/diffusion/scripts/tests/test_checkpoint.py @@ -0,0 +1,37 @@ +import os +import sys +from copy import deepcopy + +import yaml +from datetime import datetime + +from diffusers import StableDiffusionPipeline +import torch +from ldm.util import instantiate_from_config +from main import get_parser + +if __name__ == "__main__": + with torch.no_grad(): + yaml_path = "../../train_colossalai.yaml" + with open(yaml_path, 'r', encoding='utf-8') as f: + config = f.read() + base_config = yaml.load(config, Loader=yaml.FullLoader) + unet_config = base_config['model']['params']['unet_config'] + diffusion_model = instantiate_from_config(unet_config).to("cuda:0") + + pipe = StableDiffusionPipeline.from_pretrained( + "/data/scratch/diffuser/stable-diffusion-v1-4" + ).to("cuda:0") + dif_model_2 = pipe.unet + + random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0") + random_input_2 = torch.clone(random_input_).to("cuda:0") + time_stamp = torch.randint(20, (4,)).to("cuda:0") + time_stamp2 = torch.clone(time_stamp).to("cuda:0") + context_ = torch.rand((4, 77, 768)).to("cuda:0") + context_2 = torch.clone(context_).to("cuda:0") + + out_1 = diffusion_model(random_input_, time_stamp, context_) + out_2 = dif_model_2(random_input_2, time_stamp2, context_2) + print(out_1.shape) + print(out_2['sample'].shape) \ No newline at end of file diff --git a/examples/images/diffusion/scripts/tests/test_watermark.py b/examples/images/diffusion/scripts/tests/test_watermark.py new file mode 100644 index 000000000..f93f8a6e7 --- /dev/null +++ b/examples/images/diffusion/scripts/tests/test_watermark.py @@ -0,0 +1,18 @@ +import cv2 +import fire +from imwatermark import WatermarkDecoder + + +def testit(img_path): + bgr = cv2.imread(img_path) + decoder = WatermarkDecoder('bytes', 136) + watermark = decoder.decode(bgr, 'dwtDct') + try: + dec = watermark.decode('utf-8') + except: + dec = "null" + print(dec) + + +if __name__ == "__main__": + fire.Fire(testit) \ No newline at end of file diff --git a/examples/images/diffusion/scripts/train_searcher.py b/examples/images/diffusion/scripts/train_searcher.py new file mode 100644 index 000000000..1e7904889 --- /dev/null +++ b/examples/images/diffusion/scripts/train_searcher.py @@ -0,0 +1,147 @@ +import os, sys +import numpy as np +import scann +import argparse +import glob +from multiprocessing import cpu_count +from tqdm import tqdm + +from ldm.util import parallel_data_prefetch + + +def search_bruteforce(searcher): + return searcher.score_brute_force().build() + + +def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search): + return searcher.tree(num_leaves=num_leaves, + num_leaves_to_search=num_leaves_to_search, + training_sample_size=partioning_trainsize). \ + score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + + +def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): + return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( + reorder_k).build() + +def load_datapool(dpath): + + + def load_single_file(saved_embeddings): + compressed = np.load(saved_embeddings) + database = {key: compressed[key] for key in compressed.files} + return database + + def load_multi_files(data_archive): + database = {key: [] for key in data_archive[0].files} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + database[key].append(d[key]) + + return database + + print(f'Load saved patch embedding from "{dpath}"') + file_content = glob.glob(os.path.join(dpath, '*.npz')) + + if len(file_content) == 1: + data_pool = load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + else: + raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') + return data_pool + + +def train_searcher(opt, + metric='dot_product', + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None,): + + data_pool = load_datapool(opt.database) + k = opt.knn + + if not reorder_k: + reorder_k = 2 * k + + # normalize + # embeddings = + searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) + pool_size = data_pool['embedding'].shape[0] + + print(*(['#'] * 100)) + print('Initializing scaNN searcher with the following values:') + print(f'k: {k}') + print(f'metric: {metric}') + print(f'reorder_k: {reorder_k}') + print(f'anisotropic_quantization_threshold: {aiq_thld}') + print(f'dims_per_block: {dims_per_block}') + print(*(['#'] * 100)) + print('Start training searcher....') + print(f'N samples in pool is {pool_size}') + + # this reflects the recommended design choices proposed at + # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md + if pool_size < 2e4: + print('Using brute force search.') + searcher = search_bruteforce(searcher) + elif 2e4 <= pool_size and pool_size < 1e5: + print('Using asymmetric hashing search and reordering.') + searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + else: + print('Using using partioning, asymmetric hashing search and reordering.') + + if not partioning_trainsize: + partioning_trainsize = data_pool['embedding'].shape[0] // 10 + if not num_leaves: + num_leaves = int(np.sqrt(pool_size)) + + if not num_leaves_to_search: + num_leaves_to_search = max(num_leaves // 20, 1) + + print('Partitioning params:') + print(f'num_leaves: {num_leaves}') + print(f'num_leaves_to_search: {num_leaves_to_search}') + # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search) + + print('Finish training searcher') + searcher_savedir = opt.target_path + os.makedirs(searcher_savedir, exist_ok=True) + searcher.serialize(searcher_savedir) + print(f'Saved trained searcher under "{searcher_savedir}"') + +if __name__ == '__main__': + sys.path.append(os.getcwd()) + parser = argparse.ArgumentParser() + parser.add_argument('--database', + '-d', + default='data/rdm/retrieval_databases/openimages', + type=str, + help='path to folder containing the clip feature of the database') + parser.add_argument('--target_path', + '-t', + default='data/rdm/searchers/openimages', + type=str, + help='path to the target folder where the searcher shall be stored.') + parser.add_argument('--knn', + '-k', + default=20, + type=int, + help='number of nearest neighbors, for which the searcher shall be optimized') + + opt, _ = parser.parse_known_args() + + train_searcher(opt,) \ No newline at end of file diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py new file mode 100644 index 000000000..364ebac6c --- /dev/null +++ b/examples/images/diffusion/scripts/txt2img.py @@ -0,0 +1,307 @@ +import argparse, os +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +try: + from lightning.pytorch import seed_everything +except: + from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import nullcontext +from imwatermark import WatermarkEncoder + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler +from utils import replace_module, getModelSize + +torch.set_grad_enabled(False) + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.eval() + return model + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a professional photograph of an astronaut riding a triceratops", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--dpm", + action='store_true', + help="use DPM (2) sampler", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across all samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=3, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor, most often 8 or 16", + ) + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=9.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file, separated by newlines", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v2-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + parser.add_argument( + "--repeat", + type=int, + default=1, + help="repeat each prompt in file this often", + ) + parser.add_argument( + "--use_int8", + type=bool, + default=False, + help="use int8 for inference", + ) + opt = parser.parse_args() + return opt + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def main(opt): + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + model = model.to(device) + + # quantize model + if opt.use_int8: + model = replace_module(model) + # # to compute the model size + # getModelSize(model) + + if opt.plms: + sampler = PLMSSampler(model) + elif opt.dpm: + sampler = DPMSolverSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "SDV2" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = [p for p in data for i in range(opt.repeat)] + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + sample_count = 0 + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(), \ + precision_scope("cuda"), \ + model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + sample_count += 1 + + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + opt = parse_args() + main(opt) + # # to compute the mem allocated + # print(torch.cuda.max_memory_allocated() / 1024 / 1024) diff --git a/examples/images/diffusion/scripts/txt2img.sh b/examples/images/diffusion/scripts/txt2img.sh new file mode 100755 index 000000000..549bb03a6 --- /dev/null +++ b/examples/images/diffusion/scripts/txt2img.sh @@ -0,0 +1,6 @@ +python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \ + --outdir ./output \ + --config /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/checkpoints/last.ckpt \ + --ckpt /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/configs/2022-11-18T16-38-46-project.yaml \ + --n_samples 4 + diff --git a/examples/images/diffusion/scripts/utils.py b/examples/images/diffusion/scripts/utils.py new file mode 100644 index 000000000..c954b22ca --- /dev/null +++ b/examples/images/diffusion/scripts/utils.py @@ -0,0 +1,83 @@ +import bitsandbytes as bnb +import torch.nn as nn +import torch + +class Linear8bit(nn.Linear): + def __init__( + self, + input_features, + output_features, + bias=True, + has_fp16_weights=False, + memory_efficient_backward=False, + threshold=6.0, + weight_data=None, + bias_data=None + ): + super(Linear8bit, self).__init__( + input_features, output_features, bias + ) + self.state = bnb.MatmulLtState() + self.bias = bias_data + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False)) + self.weight = weight_data + self.quant() + + + def quant(self): + weight = self.weight.data.contiguous().half().cuda() + CB, _, SCB, _, _ = bnb.functional.double_quant(weight) + delattr(self, "weight") + setattr(self, "weight", nn.Parameter(CB, requires_grad=False)) + delattr(self, "SCB") + setattr(self, "SCB", nn.Parameter(SCB, requires_grad=False)) + del weight + + def forward(self, x): + self.state.is_training = self.training + + if self.bias is not None and self.bias.dtype != torch.float16: + self.bias.data = self.bias.data.half() + + self.state.CB = self.weight.data + self.state.SCB = self.SCB.data + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + del self.state.CxB + return out + +def replace_module(model): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_module(module) + + if isinstance(module, nn.Linear) and "out_proj" not in name: + model._modules[name] = Linear8bit( + input_features=module.in_features, + output_features=module.out_features, + threshold=6.0, + weight_data=module.weight, + bias_data=module.bias, + ) + return model + +def getModelSize(model): + param_size = 0 + param_sum = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + param_sum += param.nelement() + buffer_size = 0 + buffer_sum = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + buffer_sum += buffer.nelement() + all_size = (param_size + buffer_size) / 1024 / 1024 + print('Model Size: {:.3f}MB'.format(all_size)) + return (param_size, param_sum, buffer_size, buffer_sum, all_size) diff --git a/examples/images/diffusion/setup.py b/examples/images/diffusion/setup.py new file mode 100644 index 000000000..a24d54167 --- /dev/null +++ b/examples/images/diffusion/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name='latent-diffusion', + version='0.0.1', + description='', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'tqdm', + ], +) \ No newline at end of file diff --git a/examples/images/diffusion/train_colossalai.sh b/examples/images/diffusion/train_colossalai.sh new file mode 100755 index 000000000..4223a6941 --- /dev/null +++ b/examples/images/diffusion/train_colossalai.sh @@ -0,0 +1,5 @@ +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 + +python main.py --logdir /tmp -t -b /configs/train_colossalai.yaml diff --git a/examples/images/diffusion/train_ddp.sh b/examples/images/diffusion/train_ddp.sh new file mode 100644 index 000000000..78fe76548 --- /dev/null +++ b/examples/images/diffusion/train_ddp.sh @@ -0,0 +1,5 @@ +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 + +python main.py --logdir /tmp -t -b /configs/train_ddp.yaml diff --git a/examples/images/dreambooth/README.md b/examples/images/dreambooth/README.md new file mode 100644 index 000000000..a306a3abf --- /dev/null +++ b/examples/images/dreambooth/README.md @@ -0,0 +1,107 @@ +# [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) by [colossalai](https://github.com/hpcaitech/ColossalAI.git) + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. +The `train_dreambooth_colossalai.py` script shows how to implement the training procedure and adapt it for stable diffusion. + +By accommodating model data in CPU and GPU and moving the data to the computing device when necessary, [Gemini](https://www.colossalai.org/docs/advanced_tutorials/meet_gemini), the Heterogeneous Memory Manager of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) can breakthrough the GPU memory wall by using GPU and CPU memory (composed of CPU DRAM or nvme SSD memory) together at the same time. Moreover, the model scale can be further improved by combining heterogeneous training with the other parallel approaches, such as data parallel, tensor parallel and pipeline parallel. + +## Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +```bash +pip install -r requirements_colossalai.txt +``` + +### Install [colossalai](https://github.com/hpcaitech/ColossalAI.git) + +```bash +pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org +``` + +**From source** + +```bash +git clone https://github.com/hpcaitech/ColossalAI.git +python setup.py install +``` + +## Dataset for Teyvat BLIP captions +Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion). + +BLIP generated captions for characters images from [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters)and [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2). + +For each row the dataset contains `image` and `text` keys. `image` is a varying size PIL png, and `text` is the accompanying text caption. Only a train split is provided. + +The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Model type`, and `Description`, the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP). + +## Training + +The arguement `placement` can be `cpu`, `auto`, `cuda`, with `cpu` the GPU RAM required can be minimized to 4GB but will deceleration, with `cuda` you can also reduce GPU memory by half but accelerated training, with `auto` a more balanced solution for speed and memory can be obtained。 + +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export OUTPUT_DIR="path-to-save-model" + +torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=400 \ + --placement="cuda" +``` + + +### Training with prior-preservation loss + +Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data. +According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=800 \ + --placement="cuda" +``` + +## Inference + +Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt. + +```python +from diffusers import StableDiffusionPipeline +import torch + +model_id = "path-to-save-model" +pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + +prompt = "A photo of sks dog in a bucket" +image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + +image.save("dog-bucket.png") +``` diff --git a/examples/images/dreambooth/colossalai.sh b/examples/images/dreambooth/colossalai.sh new file mode 100755 index 000000000..227d8b8bd --- /dev/null +++ b/examples/images/dreambooth/colossalai.sh @@ -0,0 +1,22 @@ +export MODEL_NAME= +export INSTANCE_DIR= +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 + +torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of a dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --placement="cuda" \ diff --git a/examples/images/dreambooth/debug.py b/examples/images/dreambooth/debug.py new file mode 100644 index 000000000..c4adb4823 --- /dev/null +++ b/examples/images/dreambooth/debug.py @@ -0,0 +1,21 @@ +''' +torchrun --standalone --nproc_per_node=1 debug.py +''' + +from diffusers import AutoencoderKL + +import colossalai +from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx + +path = "/data/scratch/diffuser/stable-diffusion-v1-4" + +colossalai.launch_from_torch(config={}) +with ColoInitContext(device='cpu'): + vae = AutoencoderKL.from_pretrained( + path, + subfolder="vae", + revision=None, + ) + +for n, p in vae.named_parameters(): + print(n) diff --git a/examples/images/dreambooth/dreambooth.sh b/examples/images/dreambooth/dreambooth.sh new file mode 100644 index 000000000..e063bc827 --- /dev/null +++ b/examples/images/dreambooth/dreambooth.sh @@ -0,0 +1,12 @@ +python train_dreambooth.py \ + --pretrained_model_name_or_path= ## Your Model Path \ + --instance_data_dir= ## Your Training Input Pics Path \ + --output_dir="path-to-save-model" \ + --instance_prompt="a photo of a dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ diff --git a/examples/images/dreambooth/inference.py b/examples/images/dreambooth/inference.py new file mode 100644 index 000000000..c342821c7 --- /dev/null +++ b/examples/images/dreambooth/inference.py @@ -0,0 +1,12 @@ +from diffusers import StableDiffusionPipeline, DiffusionPipeline +import torch + +model_id = +print(f"Loading model... from{model_id}") + +pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + +prompt = "A photo of an apple." +image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + +image.save("output.png") diff --git a/examples/images/dreambooth/requirements.txt b/examples/images/dreambooth/requirements.txt new file mode 100644 index 000000000..6c4f40fb5 --- /dev/null +++ b/examples/images/dreambooth/requirements.txt @@ -0,0 +1,8 @@ +diffusers>==0.5.0 +accelerate +torchvision +transformers>=4.21.0 +ftfy +tensorboard +modelcards +colossalai diff --git a/examples/images/dreambooth/requirements_colossalai.txt b/examples/images/dreambooth/requirements_colossalai.txt new file mode 100644 index 000000000..c4a0e9170 --- /dev/null +++ b/examples/images/dreambooth/requirements_colossalai.txt @@ -0,0 +1,8 @@ +diffusers +torch +torchvision +ftfy +tensorboard +modelcards +transformers +colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org diff --git a/examples/images/dreambooth/train_dreambooth.py b/examples/images/dreambooth/train_dreambooth.py new file mode 100644 index 000000000..b989955f7 --- /dev/null +++ b/examples/images/dreambooth/train_dreambooth.py @@ -0,0 +1,694 @@ +import argparse +import hashlib +import itertools +import math +import os +from pathlib import Path +from typing import Optional + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +logger = get_logger(__name__) + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=("Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution"), + ) + parser.add_argument("--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution") + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + parser.add_argument("--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument("--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.") + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir is not None: + logger.warning("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + logger.warning("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + + if args.seed is not None: + set_seed(args.seed) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm(sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path) + + # Load models and create wrapper for stable diffusion + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + ) + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + if args.scale_lr: + args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * + accelerator.num_processes) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder else unet.parameters()) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + def collate_fn(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + { + "input_ids": input_ids + }, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + train_dataloader = torch.utils.data.DataLoader(train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=1) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, + lr_scheduler) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + global_step = 0 + + for epoch in range(args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder else unet.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if global_step % args.save_steps == 0: + if accelerator.is_main_process: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + pipeline.save_pretrained(save_path) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py new file mode 100644 index 000000000..b7e24bfe4 --- /dev/null +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -0,0 +1,677 @@ +import argparse +import hashlib +import math +import os +from pathlib import Path +from typing import Optional + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.parallel.utils import get_static_torch_model +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + +disable_existing_loggers() +logger = get_dist_logger() + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default="a photo of sks dog", + required=False, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=("Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution"), + ) + parser.add_argument( + "--placement", + type=str, + default="cpu", + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument("--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution") + parser.add_argument("--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument("--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.") + + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir is not None: + logger.warning("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + logger.warning("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +# Gemini + ZeRO DDP +def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): + from colossalai.nn.parallel import GeminiDDP + + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=64) + return model + + +def main(args): + colossalai.launch_from_torch(config={}) + + if args.seed is not None: + gpc.set_seed(args.seed) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + pipeline.to(get_current_device()) + + for example in tqdm( + sample_dataloader, + desc="Generating class images", + disable=not gpc.get_local_rank(ParallelMode.DATA) == 0, + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + + # Handle the repository creation + if gpc.get_local_rank(ParallelMode.DATA) == 0: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0]) + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + logger.info("Loading tokenizer from pretrained model", ranks=[0]) + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path) + + # Load models and create wrapper for stable diffusion + + logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0]) + + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + + logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0]) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + + logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) + with ColoInitContext(device=get_current_device()): + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + low_cpu_mem_usage=False) + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + if args.scale_lr: + args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) + + unet = gemini_zero_dpp(unet, args.placement) + + # config optimizer for colossalai zero + optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) + + # load noise_scheduler + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + # prepare dataset + logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0]) + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + def collate_fn(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + { + "input_ids": input_ids + }, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + train_dataloader = torch.utils.data.DataLoader(train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=1) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(get_current_device(), dtype=weight_dtype) + text_encoder.to(get_current_device(), dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps + + logger.info("***** Running training *****", ranks=[0]) + logger.info(f" Num examples = {len(train_dataset)}", ranks=[0]) + logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0]) + logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0]) + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0]) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0]) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0]) + logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0) + progress_bar.set_description("Steps") + global_step = 0 + + torch.cuda.synchronize() + for epoch in range(args.num_train_epochs): + unet.train() + for step, batch in enumerate(train_dataloader): + torch.cuda.reset_peak_memory_stats() + # Move batch to gpu + for key, value in batch.items(): + batch[key] = value.to(get_current_device(), non_blocking=True) + + # Convert images to latent space + optimizer.zero_grad() + + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + optimizer.backward(loss) + + optimizer.step() + lr_scheduler.step() + logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0]) + # Checks if the accelerator has performed an optimization step behind the scenes + progress_bar.update(1) + global_step += 1 + logs = { + "loss": loss.detach().item(), + "lr": optimizer.param_groups[0]["lr"], + } # lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step % args.save_steps == 0: + torch.cuda.synchronize() + torch_unet = get_static_torch_model(unet) + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=torch_unet, + revision=args.revision, + ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + pipeline.save_pretrained(save_path) + logger.info(f"Saving model checkpoint to {save_path}", ranks=[0]) + if global_step >= args.max_train_steps: + break + + torch.cuda.synchronize() + unet = get_static_torch_model(unet) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=unet, + revision=args.revision, + ) + + pipeline.save_pretrained(args.output_dir) + logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0]) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/images/dreambooth/train_dreambooth_inpaint.py b/examples/images/dreambooth/train_dreambooth_inpaint.py new file mode 100644 index 000000000..774cd4c45 --- /dev/null +++ b/examples/images/dreambooth/train_dreambooth_inpaint.py @@ -0,0 +1,720 @@ +import argparse +import hashlib +import itertools +import math +import os +import random +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image, ImageDraw +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +logger = get_logger(__name__) + + +def prepare_mask_and_masked_image(image, mask): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + return mask, masked_image + + +# generate random masks +def random_mask(im_shape, ratio=1, mask_full_image=False): + mask = Image.new("L", im_shape, 0) + draw = ImageDraw.Draw(mask) + size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio))) + # use this to always mask the whole image + if mask_full_image: + size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio)) + limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2) + center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1])) + draw_type = random.randint(0, 1) + if draw_type == 0 or mask_full_image: + draw.rectangle( + (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2), + fill=255, + ) + else: + draw.ellipse( + (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2), + fill=255, + ) + + return mask + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=("Minimal class images for prior preservation loss. If not have enough images, additional images will be" + " sampled with class_prompt."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution"), + ) + parser.add_argument("--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution") + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + parser.add_argument("--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), + ) + parser.add_argument("--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.") + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=("Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.instance_data_dir is None: + raise ValueError("You must specify a train data directory.") + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + + example["PIL_images"] = instance_image + example["instance_images"] = self.image_transforms(instance_image) + + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_PIL_images"] = class_image + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + + if args.seed is not None: + set_seed(args.seed) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + pipeline = StableDiffusionInpaintPipeline.from_pretrained(args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, + batch_size=args.sample_batch_size, + num_workers=1) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + transform_to_pil = transforms.ToPILImage() + for example in tqdm(sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process): + bsz = len(example["prompt"]) + fake_images = torch.rand((3, args.resolution, args.resolution)) + transform_to_pil = transforms.ToPILImage() + fake_pil_images = transform_to_pil(fake_images) + + fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True) + + images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + if args.scale_lr: + args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * + accelerator.num_processes) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder else unet.parameters()) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + def collate_fn(examples): + image_transforms = transforms.Compose([ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ]) + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + pior_pil = [example["class_PIL_images"] for example in examples] + + masks = [] + masked_images = [] + for example in examples: + pil_image = example["PIL_images"] + # generate a random mask + mask = random_mask(pil_image.size, 1, False) + # apply transforms + mask = image_transforms(mask) + pil_image = image_transforms(pil_image) + # prepare mask and masked image + mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) + + masks.append(mask) + masked_images.append(masked_image) + + if args.with_prior_preservation: + for pil_image in pior_pil: + # generate a random mask + mask = random_mask(pil_image.size, 1, False) + # apply transforms + mask = image_transforms(mask) + pil_image = image_transforms(pil_image) + # prepare mask and masked image + mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) + + masks.append(mask) + masked_images.append(masked_image) + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + masks = torch.stack(masks) + masked_images = torch.stack(masked_images) + batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images} + return batch + + train_dataloader = torch.utils.data.DataLoader(train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, + lr_scheduler) + + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + global_step = 0 + + for epoch in range(args.num_train_epochs): + unet.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Convert masked images to latent space + masked_latents = vae.encode(batch["masked_images"].reshape( + batch["pixel_values"].shape).to(dtype=weight_dtype)).latent_dist.sample() + masked_latents = masked_latents * 0.18215 + + masks = batch["masks"] + # resize the mask to latents shape as we concatenate the mask to the latents + mask = torch.stack([ + torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) + for mask in masks + ]) + mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # concatenate the noised latents with the mask and the masked latents + latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder else unet.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/examples/images/vit/README.md b/examples/images/vit/README.md new file mode 100644 index 000000000..4423d85d1 --- /dev/null +++ b/examples/images/vit/README.md @@ -0,0 +1,61 @@ +# Vision Transformer with ColoTensor + +# Overview + +In this example, we will run Vision Transformer with ColoTensor. + +We use model **ViTForImageClassification** from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit) for unit test. +You can change world size or decide whether use DDP in our code. + +We use model **vision_transformer** from timm [Link](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) for training example. + +(2022/6/28) The default configuration now supports 2DP+2TP with gradient accumulation and checkpoint support. Zero is not supported at present. + +# Requirement + +Install colossalai version >= 0.1.11 + +## Unit test +To run unit test, you should install pytest, transformers with: +```shell +pip install pytest transformers +``` + +## Training example +To run training example with ViT-S, you should install **NVIDIA DALI** from [Link](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) for dataloader support. +You also need to install timm and titans for model/dataloader support with: +```shell +pip install timm titans +``` + +### Data preparation +You can download the ImageNet dataset from the [ImageNet official website](https://www.image-net.org/download.php). You should get the raw images after downloading the dataset. As we use **NVIDIA DALI** to read data, we use the TFRecords dataset instead of raw Imagenet dataset. This offers better speedup to IO. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one. + +Before you start training, you need to set the environment variable `DATA` so that the script knows where to fetch the data for DALI dataloader. +```shell +export DATA=/path/to/ILSVRC2012 +``` + + +# How to run + +## Unit test +In your terminal +```shell +pytest test_vit.py +``` + +This will evaluate models with different **world_size** and **use_ddp**. + +## Training example +Modify the settings in run.sh according to your environment. +For example, if you set `--nproc_per_node=8` in `run.sh` and `TP_WORLD_SIZE=2` in your config file, +data parallel size will be automatically calculated as 4. +Thus, the parallel strategy is set to 4DP+2TP. + +Then in your terminal +```shell +sh run.sh +``` + +This will start ViT-S training with ImageNet. diff --git a/examples/images/vit/configs/vit_1d_tp2.py b/examples/images/vit/configs/vit_1d_tp2.py new file mode 100644 index 000000000..fbf399f2e --- /dev/null +++ b/examples/images/vit/configs/vit_1d_tp2.py @@ -0,0 +1,32 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 300 +WARMUP_EPOCHS = 32 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 384 +DEPTH = 12 +NUM_HEADS = 6 +MLP_RATIO = 4 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +USE_DDP = True +TP_WORLD_SIZE = 2 +TP_TYPE = 'row' +parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 +gradient_accumulation = 8 + +LOG_PATH = "./log" diff --git a/examples/images/vit/requirements.txt b/examples/images/vit/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/images/vit/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/images/vit/run.sh b/examples/images/vit/run.sh new file mode 100644 index 000000000..84fe58f11 --- /dev/null +++ b/examples/images/vit/run.sh @@ -0,0 +1,15 @@ +export DATA=/data/scratch/imagenet/tf_records +export OMP_NUM_THREADS=4 + +# resume +# CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ +# --nproc_per_node 4 train.py \ +# --config configs/vit_1d_tp2.py \ +# --resume_from checkpoint/epoch_10 \ +# --master_port 29598 | tee ./out 2>&1 + +# train +CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ +--nproc_per_node 4 train.py \ +--config configs/vit_1d_tp2.py \ +--master_port 29598 | tee ./out 2>&1 diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py new file mode 100644 index 000000000..90f2475b8 --- /dev/null +++ b/examples/images/vit/test_vit.py @@ -0,0 +1,164 @@ +import os +import random +from functools import partial + +import numpy as np +import pytest +import torch +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +from vit import get_training_components + +import colossalai +from colossalai.context import ParallelMode +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + + +def set_seed(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def tensor_equal(A, B): + return torch.allclose(A, B, rtol=1e-3, atol=1e-1) + + +def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor): + assert tensor.ndim == shard.ndim + if tensor.shape == shard.shape: + return tensor_equal(tensor, shard) + else: + dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) + if dims_not_eq.numel() == 1: + # 1D shard + dim = dims_not_eq.item() + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + return tensor_equal(tensor.chunk(world_size, dim)[rank], shard) + else: + raise + + +# Only for all Linear, it's 1d_row split because Linear will be transposed when calculating. +# But for other layers, it's 1d_col split. +# Layernorm is not supported for now. +# patch_embeddings.projection has nn.Conv2d +# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182 +def init_1d_row_for_linear_weight_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +# Similarly, it's col split for Linear but row split for others. +def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if ('weight' in n + or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +def check_param_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert tensor_shard_equal(torch_p, p) + + +def check_grad_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + if (torch_p.grad.shape == p.grad.shape): + assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True + else: + dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape)) + dim = dims_not_eq.item() + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0) == True + + +def run_vit(init_spec_func, use_ddp): + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components() + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = model.cuda() + torch_model = model_builder().cuda() + if use_ddp: + model = ColoDDP(model) + torch_model = DDP(torch_model, + device_ids=[gpc.get_global_rank()], + process_group=gpc.get_group(ParallelMode.DATA)) + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p) + + world_size = torch.distributed.get_world_size() + init_spec_func(model, world_size) + + check_param_equal(model, torch_model) + model.train() + torch_model.train() + set_seed(gpc.get_local_rank(ParallelMode.DATA)) + + optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + + for i, image_dict in enumerate(train_dataloader): + if use_ddp: + model.zero_grad() + else: + optimizer.zero_grad() + logits = model(image_dict['pixel_values']) + torch_logits = torch_model(image_dict['pixel_values']) + assert tensor_equal(torch_logits.logits, logits.logits) + loss = criterion(logits.logits, image_dict['label']) + torch_loss = criterion(torch_logits.logits, image_dict['label']) + if use_ddp: + model.backward(loss) + else: + loss.backward() + torch_loss.backward() + check_grad_equal(model, torch_model) + optimizer.step() + torch_optimizer.step() + check_param_equal(model, torch_model) + break + + +def run_dist(rank, world_size, port, use_ddp): + if use_ddp and world_size == 1: + return + tp_world_size = world_size // 2 if use_ddp else world_size + config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit(init_1d_row_for_linear_weight_spec, use_ddp) + run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize('use_ddp', [False, True]) +@rerun_if_address_is_in_use() +def test_vit(world_size, use_ddp): + run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_vit(1, False) diff --git a/examples/images/vit/train.py b/examples/images/vit/train.py new file mode 100644 index 000000000..de39801c7 --- /dev/null +++ b/examples/images/vit/train.py @@ -0,0 +1,161 @@ +import os + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from timm.models.vision_transformer import _create_vision_transformer +from titans.dataloader.imagenet import build_dali_imagenet +from tqdm import tqdm + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn._ops import * +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + + +def init_1d_row_for_linear_weight_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +# Similarly, it's col split for Linear but row split for others. +def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): + pg = ProcessGroup(tp_degree=world_size) + spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n + and 'patch_embed.proj.bias' not in n): + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +def init_spec_func(model, tp_type): + world_size = torch.distributed.get_world_size() + if tp_type == 'row': + init_1d_row_for_linear_weight_spec(model, world_size) + elif tp_type == 'col': + init_1d_col_for_linear_weight_bias_spec(model, world_size) + else: + raise NotImplemented + + +def train_imagenet(): + + parser = colossalai.get_default_parser() + parser.add_argument('--from_torch', default=True, action='store_true') + parser.add_argument('--resume_from', default=False) + + args = parser.parse_args() + colossalai.launch_from_torch(config=args.config) + use_ddp = gpc.config.USE_DDP + + disable_existing_loggers() + + logger = get_dist_logger() + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + logger.info('Build data loader', ranks=[0]) + root = os.environ['DATA'] + train_dataloader, test_dataloader = build_dali_imagenet(root, + train_batch_size=gpc.config.BATCH_SIZE, + test_batch_size=gpc.config.BATCH_SIZE) + + logger.info('Build model', ranks=[0]) + + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + embed_dim=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=gpc.config.NUM_CLASSES, + drop_rate=0.1, + attn_drop_rate=0.1, + weight_init='jax') + + with ColoInitContext(device=get_current_device()): + model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs) + init_spec_func(model, gpc.config.TP_TYPE) + + world_size = torch.distributed.get_world_size() + model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size)) + logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0]) + optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + criterion = CrossEntropyLoss() + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + start_epoch = 0 + if args.resume_from: + load_model = torch.load(args.resume_from + '_model.pth') + start_epoch = load_model['epoch'] + model.load_state_dict(load_model['model']) + load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank())) + optimizer.load_state_dict(load_optim['optim']) + + for epoch in range(start_epoch, gpc.config.NUM_EPOCHS): + model.train() + for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False): + x, y = x.cuda(), y.cuda() + output = model(x) + loss = criterion(output, y) + loss = loss / gpc.config.gradient_accumulation + if use_ddp: + model.backward(loss) + else: + loss.backward() + if (index + 1) % gpc.config.gradient_accumulation == 0: + optimizer.step() + if use_ddp: + model.zero_grad() + else: + optimizer.zero_grad() + + logger.info( + f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}", + ranks=[0]) + + model.eval() + test_loss = 0 + correct = 0 + test_sum = 0 + with torch.no_grad(): + for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False): + x, y = x.cuda(), y.cuda() + output = model(x) + test_loss += F.cross_entropy(output, y, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(y.view_as(pred)).sum().item() + test_sum += y.size(0) + + test_loss /= test_sum + logger.info( + f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})", + ranks=[0]) + + lr_scheduler.step() + + +if __name__ == '__main__': + train_imagenet() diff --git a/examples/images/vit/vit.py b/examples/images/vit/vit.py new file mode 100644 index 000000000..14c870b39 --- /dev/null +++ b/examples/images/vit/vit.py @@ -0,0 +1,92 @@ +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn +from transformers import ViTConfig, ViTForImageClassification + +from colossalai.utils.cuda import get_current_device + + +class DummyDataGenerator(ABC): + + def __init__(self, length=10): + self.length = length + + @abstractmethod + def generate(self): + pass + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +class DummyDataLoader(DummyDataGenerator): + batch_size = 4 + channel = 3 + category = 8 + image_size = 224 + + def generate(self): + image_dict = {} + image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size, + DummyDataLoader.channel, + DummyDataLoader.image_size, + DummyDataLoader.image_size, + device=get_current_device()) * 2 - 1 + image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,), + dtype=torch.int64, + device=get_current_device()) + return image_dict + + +class ViTCVModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + image_size=224, + patch_size=16, + num_channels=3, + num_labels=8, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = ViTForImageClassification( + ViTConfig(hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + image_size=image_size, + patch_size=patch_size, + num_channels=num_channels, + num_labels=num_labels)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, pixel_values): + return self.model(pixel_values=pixel_values) + + +def vit_base_s(checkpoint=True): + return ViTCVModel(checkpoint=checkpoint) + + +def vit_base_micro(checkpoint=True): + return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint) + + +def get_training_components(): + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy diff --git a/examples/language/commons/utils.py b/examples/language/commons/utils.py new file mode 100644 index 000000000..782f546dc --- /dev/null +++ b/examples/language/commons/utils.py @@ -0,0 +1,12 @@ +import torch + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md new file mode 100644 index 000000000..8fdf6be3b --- /dev/null +++ b/examples/language/gpt/README.md @@ -0,0 +1,72 @@ +# Train GPT with Colossal-AI + +This example shows how to use [Colossal-AI](https://github.com/hpcaitech/ColossalAI) to run huggingface GPT training in distributed manners. + +## GPT + +We use the [GPT-2](https://huggingface.co/gpt2) model from huggingface transformers. The key learning goal of GPT-2 is to use unsupervised pre-training models to do supervised tasks.GPT-2 has an amazing performance in text generation, and the generated text exceeds people's expectations in terms of contextual coherence and emotional expression. + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install [Colossal-AI v0.1.12](https://colossalai.org/download/) From Official Website + +```bash +pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +This is just an example that we download PyTorch=1.12.0, CUDA=11.6 and colossalai=0.1.12+torch1.12cu11.3. You can download another version of PyTorch and its corresponding ColossalAI version. Just make sure that the version of ColossalAI is at least 0.1.10, PyTorch is at least 1.8.1 and transformers is at least 4.231. +If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal-AI>=0.1.12. + +## Dataset + +For simplicity, the input data is randonly generated here. + +## Training +We provide two solutions. One utilizes the hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism. +The other one uses Pipeline Parallelism Only. +In the future, we are going merge them together and they can be used orthogonally to each other. + +### GeminiDPP/ZeRO + Tensor Parallelism +```bash +bash run_gemini.sh +``` + +The `train_gpt_demo.py` provides three distributed plans, you can choose the plan you want in `run_gemini.sh`. The Colossal-AI leverages Tensor Parallel and Gemini + ZeRO DDP. + +- Colossal-AI +- ZeRO1 (Colossal-AI) +- ZeRO2 (Colossal-AI) +- Pytorch DDP +- Pytorch ZeRO + + +## Performance + +Testbed: a cluster of 8xA100 (80GB) and 1xAMD EPYC 7543 32-Core Processor (512 GB). GPUs are connected via PCI-e. +ColossalAI version 0.1.13. + +[benchmark results on google doc](https://docs.google.com/spreadsheets/d/15A2j3RwyHh-UobAPv_hJgT4W_d7CnlPm5Fp4yEzH5K4/edit#gid=0) + +[benchmark results on Tencent doc (for china)](https://docs.qq.com/sheet/DUVpqeVdxS3RKRldk?tab=BB08J2) + +### Experimental Features + +#### [Pipeline Parallel](./experiments/pipeline_parallel/) +#### [Auto Parallel](./experiments/auto_parallel_with_gpt/) diff --git a/examples/language/gpt/experiments/auto_parallel/README.md b/examples/language/gpt/experiments/auto_parallel/README.md new file mode 100644 index 000000000..404c83911 --- /dev/null +++ b/examples/language/gpt/experiments/auto_parallel/README.md @@ -0,0 +1,44 @@ +# Auto-Parallelism with GPT2 + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website + +```bash +pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Install transformers + +```bash +pip install transformers +``` + +### Install pulp and coin-or-cbc + +```bash +pip install pulp +conda install -c conda-forge coin-or-cbc +``` + +## Dataset + +For simplicity, the input data is randonly generated here. + +## Training + +```bash +#Run the auto parallel resnet example with 4 GPUs with a dummy dataset. +colossalai run --nproc_per_node 4 auto_parallel_with_gpt.py +``` diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py new file mode 100644 index 000000000..85c8d64d7 --- /dev/null +++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py @@ -0,0 +1,109 @@ +from functools import partial +from time import time +from typing import Dict, Optional, Tuple, Union + +import psutil +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import transformers +from gpt_modules import GPT2LMHeadModel, GPTLMLoss +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch_from_torch +from colossalai.logging import disable_existing_loggers, get_dist_logger + +BATCH_SIZE = 8 +SEQ_LENGTH = 128 +HIDDEN_DIM = 3072 +NUM_HEADS = 16 +NUM_LAYERS = 1 +VOCAB_SIZE = 50257 +NUM_STEPS = 10 +FP16 = False + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4 + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def main(): + disable_existing_loggers() + launch_from_torch(config={}) + logger = get_dist_logger() + config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM) + if FP16: + model = GPT2LMHeadModel(config=config).half().to('cuda') + else: + model = GPT2LMHeadModel(config=config).to('cuda') + global_numel = sum([p.numel() for p in model.parameters()]) + + meta_input_sample = { + 'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + 'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + } + + # Both device mesh initialization and model initialization will be integrated into autoparallelize + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # Enable auto-parallel + gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True) + + # print solution on rank 0 + if gpc.get_global_rank() == 0: + for node_strategy in solution: + print(node_strategy) + + # build criterion + criterion = GPTLMLoss() + + optimizer = torch.optim.Adam(gm.parameters(), lr=0.01) + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH) + torch.cuda.synchronize() + model.train() + + for n in range(10): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE) + optimizer.zero_grad() + start = time() + outputs = gm(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + loss.backward() + optimizer.step() + torch.cuda.synchronize() + step_time = time() - start + logger.info( + f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', + ranks=[0]) + torch.cuda.synchronize() + + +if __name__ == '__main__': + main() diff --git a/examples/language/gpt/experiments/auto_parallel/gpt_modules.py b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py new file mode 100644 index 000000000..95feaec38 --- /dev/null +++ b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py @@ -0,0 +1,253 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel +from transformers.pytorch_utils import Conv1D + + +class GPT2MLP(nn.Module): + + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + return hidden_states + + +# The reason Why we don't import GPT2Attention from transformers directly is that: +# 1. The tracer will not work correctly when we feed meta_args and concrete_args at same time, +# so we have to build the customized GPT2Attention class and remove the conditional branch manually. +# 2. The order of split and view op has been changed in the customized GPT2Attention class, the new +# order is same as megatron-lm gpt model. +class GPT2Attention(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), + dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + self.scale_attn_weights = config.scale_attn_weights + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / (value.size(-1)**0.5) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.type(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + + qkv = self.c_attn(hidden_states) + query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) + present = (key, value) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + return attn_output + + +class GPT2Block(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + ) + # residual connection + hidden_states = attn_outputs + residual + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + return hidden_states + + +class GPT2Model(GPT2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + + device = input_ids.device + + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + + hidden_states = inputs_embeds + position_embeds + + output_shape = input_shape + (hidden_states.size(-1),) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i]) + hidden_states = outputs + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + + return hidden_states + + +class GPT2LMHeadModel(GPT2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ): + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + ) + lm_logits = self.lm_head(transformer_outputs) + + return lm_logits + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) diff --git a/examples/language/gpt/experiments/auto_parallel/requirements.txt b/examples/language/gpt/experiments/auto_parallel/requirements.txt new file mode 100644 index 000000000..ff046ad1c --- /dev/null +++ b/examples/language/gpt/experiments/auto_parallel/requirements.txt @@ -0,0 +1,4 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 +transformers >= 4.231 +PuLP >= 2.7.0 diff --git a/examples/language/gpt/experiments/pipeline_parallel/README.md b/examples/language/gpt/experiments/pipeline_parallel/README.md new file mode 100644 index 000000000..702e3c8d6 --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/README.md @@ -0,0 +1,38 @@ +# Pipeline Parallelism Demo with GPT2 + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website + +```bash +pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Install transformers + +```bash +pip install transformers +``` + +## Dataset + +For simplicity, the input data is randonly generated here. + +## Training + +```bash +#Run the Pipeline Parallel on GPT with default setting and a dummy dataset. +#You can change the GPU number or microbatch number in the run.sh . +bash run.sh +``` diff --git a/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py new file mode 100644 index 000000000..c31b3fa6d --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py @@ -0,0 +1,73 @@ +from torch import nn +from transformers import GPT2Config, GPT2LMHeadModel + + +## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.config = GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size) + self.model = GPT2LMHeadModel(self.config) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=True): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_10b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_14b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_20b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_24b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) + + +def model_builder(model_size: str) -> callable: + if model_size == "gpt2_medium": + return gpt2_medium + elif model_size == "gpt2_xl": + return gpt2_xl + elif model_size == "gpt2_10b": + return gpt2_10b + elif model_size == "gpt2_14b": + return gpt2_14b + elif model_size == "gpt2_20b": + return gpt2_20b + elif model_size == "gpt2_24b": + return gpt2_24b + else: + raise TypeError(f"model_builder {model_size}") + + +__all__ = ['model_builder'] diff --git a/examples/language/gpt/experiments/pipeline_parallel/run.sh b/examples/language/gpt/experiments/pipeline_parallel/run.sh new file mode 100644 index 000000000..235cefcbc --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/run.sh @@ -0,0 +1,7 @@ +export GPUNUM=${GPUNUM:-4} +export BATCH_SIZE=${BATCH_SIZE:-16} +export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} +export NUM_MICROBATCH=${NUM_MICROBATCH:-8} + +mkdir -p pp_logs +python train_gpt_pp.py --device="cuda" --model_type=${MODEL_TYPE} --num_microbatches=${NUM_MICROBATCH} --world_size=${GPUNUM} --batch_size=${BATCH_SIZE} 2>&1 | tee ./pp_logs/${MODEL_TYPE}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_nm_${NUM_MICROBATCH}.log diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py new file mode 100644 index 000000000..79efa61b0 --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -0,0 +1,161 @@ +import argparse +import time +from functools import partial + +import torch +from model_zoo import model_builder +from torch import nn +from tqdm import tqdm + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.pipeline.middleware.adaptor import get_fx_topology +from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.pipeline.rpc.utils import rpc_run + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_type', type=str, default="gpt2_medium") + parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--dp_degree', type=int, default=1) + parser.add_argument('--tp_degree', type=int, default=1) + parser.add_argument('--num_microbatches', type=int, default=2) + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29011') + parser.add_argument('--num_worker_threads', type=int, default=128) + return parser.parse_args() + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): + tracer = ColoTracer() + meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) + annotated_model = avgnode_split_pass(gm, stage_num) + + top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True) + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + return split_submodules[pp_rank + 1] + + +def partition(model, data_kwargs, pp_rank: int, chunk: int, stage_num: int): + module = create_partition_module(pp_rank, stage_num, model, data_kwargs) + return module + + +def run_master(args): + batch_size = args.batch_size + device = args.device + world_size = args.world_size + stage_num = world_size + num_microbatches = args.num_microbatches + model_type = args.model_type + # batch size per DP degree + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + NUM_STEPS = 10 + WARMUP_STEPS = 1 + + disable_existing_loggers() + logger = get_dist_logger() + logger.info(f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}", + ranks=[0]) + + torch.manual_seed(123) + + # build criterion + criterion = GPTLMLoss() + + # warm up pipeline fx partition + input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) + warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask} + + # create model + model = model_builder(model_type)(checkpoint=False) + + # set 1f1b pipeline engine + pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=1, + criterion=criterion, + metric=None, + checkpoint=False) + + partition_numels = pp_engine.remote_numels() + for rank, numel in partition_numels.items(): + logger.info(f'{rank=} numel in the partition:{numel}') + + # build optim + pp_engine.initialize_optimizer(HybridAdam, lr=1e-3) + + ranks_tflops = {} + for n in range(NUM_STEPS): + # we just use randomly generated data here + input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) + batch = {'input_ids': input_ids, 'attention_mask': attn_mask} + + start = time.time() + outputs = pp_engine.forward_backward(batch=batch, labels=input_ids, forward_only=False) + step_time = time.time() - start + + for rank, numel in partition_numels.items(): + if rank not in ranks_tflops: + ranks_tflops[rank] = [] + step_tflops = get_tflops(numel, batch_size, SEQ_LEN, step_time) + + logger.info( + f"Rank{rank} , [{n + 1}/{NUM_STEPS}] , Step time: {step_time:.3f}s, TFLOPS: {get_tflops(numel, batch_size, SEQ_LEN, step_time):.3f}", + ranks=[0], + ) + + if n >= WARMUP_STEPS: + ranks_tflops[rank].append(step_tflops) + + median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS + gpu_tflops = [] + for rank, tflops_list in ranks_tflops.items(): + tflops_list.sort() + gpu_tflops.append(tflops_list[median_index]) + logger.info(f"GPU{rank} Median TFLOPS is {tflops_list[median_index]:.3f}") + + logger.info(f"Total TFLOPS is {sum(gpu_tflops):.3f}") + logger.info(f"Avg TFLOPS per GPU is {sum(gpu_tflops) / world_size:.3f}") + + +if __name__ == '__main__': + args = parse_args() + rpc_run(args, run_master) diff --git a/examples/language/gpt/gemini/benchmark_gemini.sh b/examples/language/gpt/gemini/benchmark_gemini.sh new file mode 100644 index 000000000..464ea03da --- /dev/null +++ b/examples/language/gpt/gemini/benchmark_gemini.sh @@ -0,0 +1,22 @@ +for MODEL_TYPE in "gpt2_medium"; do + for DISPAN in "colossalai"; do + for BATCH_SIZE in 16; do + for GPUNUM in 1 2 4 8; do + for TPDEGREE in 1 2 4 8; do + if [ ${TPDEGREE} -gt ${GPUNUM} ]; then + continue + fi + for PLACEMENT in "cpu" "auto"; do + echo "****************** Begin ***************************" + echo "+ benchmrking MODEL ${MODEL_TYPE} DISPAN ${DISPAN} GPU ${GPUNUM} BS ${BATCH_SIZE} TP ${TPDEGREE} POLICY ${PLACEMENT}" + MODEL_TYPE=${MODEL_TYPE} DISPAN=${DISPAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \ + bash ./run_gemini.sh + echo "****************** Finished ***************************" + echo "" + echo "" + done + done + done + done + done +done diff --git a/examples/language/gpt/gemini/commons/model_zoo.py b/examples/language/gpt/gemini/commons/model_zoo.py new file mode 100644 index 000000000..c31b3fa6d --- /dev/null +++ b/examples/language/gpt/gemini/commons/model_zoo.py @@ -0,0 +1,73 @@ +from torch import nn +from transformers import GPT2Config, GPT2LMHeadModel + + +## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.config = GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size) + self.model = GPT2LMHeadModel(self.config) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=True): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_10b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_14b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_20b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_24b(checkpoint=True): + return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) + + +def model_builder(model_size: str) -> callable: + if model_size == "gpt2_medium": + return gpt2_medium + elif model_size == "gpt2_xl": + return gpt2_xl + elif model_size == "gpt2_10b": + return gpt2_10b + elif model_size == "gpt2_14b": + return gpt2_14b + elif model_size == "gpt2_20b": + return gpt2_20b + elif model_size == "gpt2_24b": + return gpt2_24b + else: + raise TypeError(f"model_builder {model_size}") + + +__all__ = ['model_builder'] diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py new file mode 100644 index 000000000..782f546dc --- /dev/null +++ b/examples/language/gpt/gemini/commons/utils.py @@ -0,0 +1,12 @@ +import torch + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh new file mode 100644 index 000000000..ad577c350 --- /dev/null +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -0,0 +1,24 @@ +set -x +# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"] +export DISTPAN=${DISTPAN:-"colossalai"} + +# The following options only valid when DISTPAN="colossalai" +export GPUNUM=${GPUNUM:-1} +export TPDEGREE=${TPDEGREE:-1} +export PLACEMENT=${PLACEMENT:-"cpu"} +export USE_SHARD_INIT=${USE_SHARD_INIT:-False} +export BATCH_SIZE=${BATCH_SIZE:-16} +export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} + +# export PYTHONPATH=$PWD:$PYTHONPATH + +mkdir -p gemini_logs + +torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \ +--tp_degree=${TPDEGREE} \ +--model_type=${MODEL_TYPE} \ +--batch_size=${BATCH_SIZE} \ +--placement=${PLACEMENT} \ +--shardinit=${USE_SHARD_INIT} \ +--distplan=${DISTPAN} \ +2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py new file mode 100644 index 000000000..891b1de15 --- /dev/null +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -0,0 +1,358 @@ +import os +from functools import partial +from time import time + +import psutil +import torch +import torch.nn as nn +from commons.model_zoo import model_builder +from commons.utils import get_data, get_tflops +from packaging import version +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + +CAI_VERSION = colossalai.__version__ + +if version.parse(CAI_VERSION) > version.parse("0.1.10"): + # These are added after 0.1.10 + from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer + from colossalai.nn.parallel import GeminiDDP + from colossalai.zero.sharded_optim import LowLevelZeroOptimizer + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--distplan", + type=str, + default='colossalai', + help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", + ) + parser.add_argument( + "--tp_degree", + type=int, + default=1, + help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--shardinit", + type=bool, + default=False, + help= + "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=8, + help="batch size per DP group of training.", + ) + parser.add_argument( + "--model_type", + type=str, + default="gpt2_medium", + help="model model scale", + ) + args = parser.parse_args() + return args + + +# Parameter Sharding Strategies for Tensor Parallelism +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_model_size(model: nn.Module): + total_numel = 0 + for module in model.modules(): + for p in module.parameters(recurse=False): + total_numel += p.numel() + return total_numel + + +def model_size_formatter(numel: int) -> str: + GB_SIZE = 10**9 + MB_SIZE = 10**6 + KB_SIZE = 10**3 + if numel >= GB_SIZE: + return f'{numel / GB_SIZE:.1f}B' + elif numel >= MB_SIZE: + return f'{numel / MB_SIZE:.1f}M' + elif numel >= KB_SIZE: + return f'{numel / KB_SIZE:.1f}K' + else: + return str(numel) + + +def set_cpu_maximum_parallelism(): + conf_str = torch.__config__.parallel_info() + inter_str = conf_str.split("hardware_concurrency() : ")[1] + max_concurrency = inter_str.split('\n')[0] + os.environ["OMP_NUM_THREADS"] = max_concurrency + print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.") + + +# Tensor Parallel +def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): + """tensor_parallelize + Sharding the Model Parameters. + + Args: + model (torch.nn.Module): a torch module to be sharded + """ + for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + # NOTE() a param maybe shared by two modules + if hasattr(param, 'visited'): + continue + + # if shard init, then convert param to replica and use the dp-only ProcessGroup + param: ColoParameter = param + param.set_dist_spec(ReplicaSpec()) + param.set_process_group(pg) + + # shard it w.r.t tp pattern + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) # colmn slice + # keep the shape of the output from c_fc + param.compute_spec.set_output_replicate(False) + else: + param.set_dist_spec(ReplicaSpec()) + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) # row slice + else: + param.set_dist_spec(ReplicaSpec()) + elif 'wte' in mn or 'wpe' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) # colmn slice + else: + param.set_dist_spec(ReplicaSpec()) + param.visited = True + + +# Gemini + ZeRO DDP +def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"): + fp16_init_scale = 2**5 + gpu_margin_mem_ratio_for_auto = 0 + + if version.parse(CAI_VERSION) > version.parse("0.1.10"): + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placement_policy, + pin_memory=True, + hidden_dim=model.config.n_embd, + search_range_mb=64) + # configure the const policy + if placement_policy == 'const': + model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024) + # build a highly optimized cpu optimizer + optimizer = GeminiAdamOptimizer(model, + lr=1e-3, + initial_scale=fp16_init_scale, + gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto) + elif version.parse("0.1.9") <= version.parse(CAI_VERSION) <= version.parse("0.1.10"): + from colossalai.gemini import ChunkManager, GeminiManager + from colossalai.nn.optimizer import HybridAdam + from colossalai.zero import ZeroOptimizer + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024, filter_exlarge_params=True) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(placement_policy)) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager) + optimizer = HybridAdam(model.parameters(), lr=1e-3) + optimizer = ZeroOptimizer(optimizer, + model, + initial_scale=fp16_init_scale, + gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto) + else: + raise NotImplemented(f"CAI version {CAI_VERSION} is not supported") + return model, optimizer + + +def main(): + # version check + # this example is supposed to work for versions greater than 0.1.9 + assert version.parse(CAI_VERSION) >= version.parse("0.1.9") + + set_cpu_maximum_parallelism() + args = parse_args() + + if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]: + raise TypeError(f"{args.distplan} is error") + + # batch size per DP degree + BATCH_SIZE = args.batch_size + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + + NUM_STEPS = 10 + WARMUP_STEPS = 1 + assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" + assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median " + + disable_existing_loggers() + colossalai.launch_from_torch(config={}) + + logger = get_dist_logger() + logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0]) + + # build criterion + criterion = GPTLMLoss() + + torch.manual_seed(123) + if args.distplan == "colossalai": + # all param must use the same process group. + world_size = torch.distributed.get_world_size() + shard_pg = ProcessGroup(tp_degree=world_size) + default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None + + # build GPT model + if version.parse(CAI_VERSION) > version.parse("0.1.10"): + with ColoInitContext(device=get_current_device(), + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg): + model = model_builder(args.model_type)(checkpoint=True) + else: + with ColoInitContext(device=get_current_device()): + model = model_builder(args.model_type)(checkpoint=True) + + tp_pg = ProcessGroup(tp_degree=args.tp_degree) + # Tensor Parallelism (TP) + # You should notice that v0.1.10 is not compatible with TP degree > 1 + tensor_parallelize(model, tp_pg) + + # build a Gemini model and a highly optimized cpu optimizer + # Gemini + ZeRO DP, Note it must be used after TP + model, optimizer = build_gemini(model, tp_pg, args.placement) + + logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + else: + assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples." + model = model_builder(args.model_type)(checkpoint=True).cuda() + + if args.distplan.startswith("torch"): + model = DDP(model) + if args.distplan.endswith("ddp"): + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + elif args.distplan.endswith("zero"): + from torch.distributed.optim import ZeroRedundancyOptimizer + optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01) + elif args.distplan.startswith("zero"): + partition_flag = args.distplan == "zero2" + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + optimizer = LowLevelZeroOptimizer(optimizer, + overlap_communication=True, + partition_grad=partition_flag, + verbose=True) + + # model is shared after TP + numel = get_model_size(model) + logger.info(f"the size of testing model size is {model_size_formatter(numel)}.") + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + + # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu + # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree) + # = batch_per_DP_group * numel * seq_len * 8 + get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) + + torch.cuda.synchronize() + model.train() + tflops_list = [] + for n in range(NUM_STEPS): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE) + optimizer.zero_grad() + + start = time() + outputs = model(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + torch.cuda.synchronize() + fwd_end = time() + fwd_time = fwd_end - start + logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0]) + + if args.distplan in ["colossalai", "zero1", "zero2"]: + optimizer.backward(loss) + elif args.distplan in ["torch_ddp", "torch_zero"]: + loss.backward() + torch.cuda.synchronize() + bwd_end = time() + bwd_time = bwd_end - fwd_end + logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0]) + + if args.distplan in ["zero1", "zero2"]: + optimizer.sync_grad() + optimizer.step() + torch.cuda.synchronize() + optim_time = time() - bwd_end + step_time = time() - start + logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0]) + + step_tflops = get_tflops_func(step_time) + logger.info( + f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s", + ranks=[0], + ) + if n >= WARMUP_STEPS: + tflops_list.append(step_tflops) + + tflops_list.sort() + median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS + logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") + torch.cuda.synchronize() + + +if __name__ == '__main__': + main() diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt new file mode 100644 index 000000000..e1f131468 --- /dev/null +++ b/examples/language/gpt/requirements.txt @@ -0,0 +1 @@ +transformers >= 4.23 diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh new file mode 100644 index 000000000..ad0cfa325 --- /dev/null +++ b/examples/language/gpt/test_ci.sh @@ -0,0 +1,16 @@ +pip install -r requirements.txt + +# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"] +export DISTPAN="colossalai" + +# The following options only valid when DISTPAN="colossalai" +export TPDEGREE=2 +export GPUNUM=4 +export PLACEMENT='cpu' +export USE_SHARD_INIT=False +export BATCH_SIZE=8 +export MODEL_TYPE="gpt2_medium" + + +mkdir -p logs +torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md new file mode 100644 index 000000000..c2fd25457 --- /dev/null +++ b/examples/language/opt/README.md @@ -0,0 +1,33 @@ + + +## OPT +Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. + +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. + +We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before +the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). + +## Our Modifications +We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. + +## Quick Start +You can launch training by using the following bash script + +```bash +bash ./run_gemini.sh +``` diff --git a/examples/language/opt/benchmark.sh b/examples/language/opt/benchmark.sh new file mode 100644 index 000000000..0d04b5e9b --- /dev/null +++ b/examples/language/opt/benchmark.sh @@ -0,0 +1,21 @@ +export BS=16 +export MEMCAP=0 +export MODEL="6.7b" +export GPUNUM=1 + +for MODEL in "6.7b" "13b" "1.3b" +do +for GPUNUM in 8 1 +do +for BS in 16 24 32 8 +do +for MEMCAP in 0 40 +do +pkill -9 torchrun +pkill -9 python + +env BS=$BS MEM_CAP=$MEMCAP MODEL=$MODEL GPUNUM=$GPUNUM bash ./run_gemini.sh +done +done +done +done diff --git a/examples/language/opt/run_gemini.sh b/examples/language/opt/run_gemini.sh new file mode 100644 index 000000000..d9625723a --- /dev/null +++ b/examples/language/opt/run_gemini.sh @@ -0,0 +1,20 @@ +set -x +export BS=${BS:-16} +export MEMCAP=${MEMCAP:-0} +# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b` +export MODEL=${MODEL:-"125m"} +export GPUNUM=${GPUNUM:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + train_gemini_opt.py \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py new file mode 100755 index 000000000..64426ba42 --- /dev/null +++ b/examples/language/opt/train_gemini_opt.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. + +import time +from functools import partial + +import datasets +import torch +import torch.distributed as dist +import transformers +from transformers import CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, OPTForCausalLM +from transformers.utils.versions import require_version + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.parallel import GeminiDDP +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + + +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def get_time_stamp(): + torch.cuda.synchronize() + return time.time() + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--batch_size", + type=int, + default=8, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument( + "--max_train_steps", + type=int, + default=20, + help="Total number of training steps to perform.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") + parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + args = parser.parse_args() + + return args + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print("Using {} GB of GPU memory".format(size_in_GB)) + + +def main(): + args = parse_args() + disable_existing_loggers() + colossalai.launch_from_torch({}) + logger = get_dist_logger() + is_main_process = dist.get_rank() == 0 + + if is_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # If passed along, set the training seed now. + if args.seed is not None: + torch.mannul_seed(args.seed) + logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + logger.info("Model config has been created", ranks=[0]) + + if args.init_in_cpu: + init_dev = torch.device('cpu') + else: + init_dev = get_current_device() + + # build model + if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': + # currently, there has a bug in pretrained opt-13b + # we can not import it until huggingface fix it + logger.info("Train a new model from scratch", ranks=[0]) + with ColoInitContext(device=init_dev, dtype=torch.half): + model = OPTForCausalLM(config) + else: + logger.info("Finetune a pre-trained model", ranks=[0]) + with ColoInitContext(device=init_dev, dtype=torch.half): + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False) + + # enable graident checkpointing + model.gradient_checkpointing_enable() + + numel = sum([p.numel() for p in model.parameters()]) + PLACEMENT_POLICY = 'cpu' + model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) + + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + + get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN) + + model.train() + for step in range(args.max_train_steps): + st_time = time.time() + input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) + + outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False) + loss = outputs['loss'] + optimizer.backward(loss) + + optimizer.step() + optimizer.zero_grad() + torch.cuda.synchronize() + step_time = time.time() - st_time + step_tflops = get_tflops_func(step_time) + + logger.info("step {} finished, Tflops {}".format(step, step_tflops), ranks=[0]) + + logger.info("Training finished", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/language/palm/README.md b/examples/language/palm/README.md new file mode 100644 index 000000000..486bf240f --- /dev/null +++ b/examples/language/palm/README.md @@ -0,0 +1,64 @@ + + +## PaLM - Pytorch + +Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways, in less than 200 lines of code. + +This model is pretty much SOTA on everything language. + +It obviously will not scale, but it is just for educational purposes. To elucidate the public how simple it all really is. + +## Install +```bash +$ pip install PaLM-pytorch +``` + +## Usage + +```python +import torch +from palm_pytorch import PaLM + +palm = PaLM( + num_tokens = 20000, + dim = 512, + depth = 12, + heads = 8, + dim_head = 64, +) + +tokens = torch.randint(0, 20000, (1, 2048)) +logits = palm(tokens) # (1, 2048, 20000) +``` + +The PaLM 540B in the paper would be + +```python +palm = PaLM( + num_tokens = 256000, + dim = 18432, + depth = 118, + heads = 48, + dim_head = 256 +) +``` + +## Test on Enwik8 + +```bash +$ python train.py +``` + +## Todo + +- [ ] offer a Triton optimized version of PaLM, bringing in https://github.com/lucidrains/triton-transformer + +## Citations + +```bibtex +@article{chowdhery2022PaLM, + title = {PaLM: Scaling Language Modeling with Pathways}, + author = {Chowdhery, Aakanksha et al}, + year = {2022} +} +``` diff --git a/examples/language/palm/data/README.md b/examples/language/palm/data/README.md new file mode 100644 index 000000000..56433b4dc --- /dev/null +++ b/examples/language/palm/data/README.md @@ -0,0 +1,3 @@ +# Data source + +The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ diff --git a/examples/language/palm/palm_pytorch/__init__.py b/examples/language/palm/palm_pytorch/__init__.py new file mode 100644 index 000000000..dab49645a --- /dev/null +++ b/examples/language/palm/palm_pytorch/__init__.py @@ -0,0 +1 @@ +from palm_pytorch.palm_pytorch import PaLM diff --git a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py new file mode 100644 index 000000000..dc4f3d856 --- /dev/null +++ b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py @@ -0,0 +1,77 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +# helper function + + +def exists(val): + return val is not None + + +def eval_decorator(fn): + + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + + return inner + + +# top k filtering + + +def top_k(logits, thres=0.9): + k = int((1 - thres) * logits.shape[-1]) + val, ind = torch.topk(logits, k) + probs = torch.full_like(logits, float("-inf")) + probs.scatter_(1, ind, val) + return probs + + +class AutoregressiveWrapper(nn.Module): + + def __init__(self, net, max_seq_len=2048, pad_value=0): + super().__init__() + self.max_seq_len = max_seq_len + self.pad_value = pad_value + self.net = net + + @torch.no_grad() + @eval_decorator + def generate(self, start_tokens, seq_len, eos_token=None, temperature=1.0, filter_thres=0.9, **kwargs): + b, t, device = *start_tokens.shape, start_tokens.device + + out = start_tokens + + for _ in range(seq_len): + logits = self.net(out, **kwargs)[:, -1, :] + + filtered_logits = top_k(logits, thres=filter_thres) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + sample = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + if exists(eos_token): + is_eos_token = out == eos_token + + if is_eos_token.any(dim=-1).all(): + # mask out everything after the eos tokens + shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) + mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1 + out = out.masked_fill(mask, self.pad_value) + break + + out = out[:, t:] + return out + + def forward(self, x, **kwargs): + x_inp, x_labels = x[:, :-1], x[:, 1:] + logits = self.net(x_inp, **kwargs) + return F.cross_entropy(rearrange(logits, "b c n -> b n c"), x_labels) diff --git a/examples/language/palm/palm_pytorch/palm_pytorch.py b/examples/language/palm/palm_pytorch/palm_pytorch.py new file mode 100644 index 000000000..c37974711 --- /dev/null +++ b/examples/language/palm/palm_pytorch/palm_pytorch.py @@ -0,0 +1,207 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import einsum, matmul, nn + +# normalization +# they use layernorm without bias, something that pytorch does not offer + + +class LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + + +# parallel with residual +# discovered by Wang et al + EleutherAI from GPT-J fame + + +class ParallelResidual(nn.Module): + + def __init__(self, *fns): + super().__init__() + self.fns = nn.ModuleList(fns) + + def forward(self, x): + return x + sum([fn(x) for fn in self.fns]) + + +# rotary positional embedding +# https://arxiv.org/abs/2104.09864 + + +class RotaryEmbedding(nn.Module): + + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, max_seq_len, *, device): + seq = torch.arange(max_seq_len, device=device) + #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) + #freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) + i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq) + freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j)) + return torch.cat((freqs, freqs), dim=-1) + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(pos, t): + return (t * pos.cos()) + (rotate_half(t) * pos.sin()) + + +# feedforward +# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU +# https://arxiv.org/abs/2002.05202 + + +class SwiGLU(nn.Module): + + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + LayerNorm(dim), + nn.Linear(dim, inner_dim * 2, bias=False), + SwiGLU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +# attention +class Attention(nn.Module): + + def __init__(self, dim, dim_head=64, heads=8): + super().__init__() + inner_dim = dim_head * heads + self.norm = LayerNorm(dim) + self.heads = heads + self.scale = dim_head**-0.5 + self.rotary_emb = RotaryEmbedding(dim_head) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # for caching causal mask and rotary embeddings + + self.register_buffer("mask", None, persistent=False) + self.register_buffer("pos_emb", None, persistent=False) + + def get_mask(self, n, device): + if self.mask is not None and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def get_rotary_embedding(self, n, device): + if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: + return self.pos_emb[:n] + + pos_emb = self.rotary_emb(n, device=device) + self.register_buffer("position", pos_emb, persistent=False) + return pos_emb + + def forward(self, x): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device, h = x.shape[1], x.device, self.heads + + # pre layernorm + + x = self.norm(x) + + # queries, keys, values + + q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) + + # split heads + # they use multi-query single-key-value attention, yet another Noam Shazeer paper + # they found no performance loss past a certain scale, and more efficient decoding obviously + # https://arxiv.org/abs/1911.02150 + + q = rearrange(q, "b n (h d) -> b h n d", h=h) + + # rotary embeddings + + positions = self.get_rotary_embedding(n, device) + q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) + + # scale + + q = q * self.scale + + b, h, i, d, j = q.size(0), q.size(1), q.size(2), q.size(3), k.size(1) + + # similarity + + #sim = einsum("b h i d, b j d -> b h i j", q, k) + sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2)) + sim = sim.reshape(b, h, i, j) + + # causal mask + + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + b_, h_, i_, j_, d_ = attn.size(0), attn.size(1), attn.size(2), attn.size(3), v.size(2) + + # aggregate values + + #out = einsum("b h i j, b j d -> b h i d", attn, v) + out = matmul(attn.reshape(b_, h_ * i_, j_), v) + out = out.reshape(b_, h_, i_, d_) + + # merge heads + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +# transformer + + +def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4): + net = nn.Sequential( + nn.Embedding(num_tokens, dim), *[ + ParallelResidual( + Attention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ) for _ in range(depth) + ], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False)) + + # they used embedding weight tied projection out to logits, not common, but works + net[-1].weight = net[0].weight + + nn.init.normal_(net[0].weight, std=0.02) + return net diff --git a/examples/language/palm/requirements.txt b/examples/language/palm/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/language/palm/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh new file mode 100644 index 000000000..4aa868953 --- /dev/null +++ b/examples/language/palm/run.sh @@ -0,0 +1,11 @@ +# distplan in ["colossalai", "pytorch"] +export DISTPAN="colossalai" + +# The following options only valid when DISTPAN="colossalai" +export TPDEGREE=1 +export GPUNUM=1 +export PLACEMENT='cpu' +export USE_SHARD_INIT=False +export BATCH_SIZE=4 + +env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train_new.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log \ No newline at end of file diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py new file mode 100644 index 000000000..7c080b7f3 --- /dev/null +++ b/examples/language/palm/train.py @@ -0,0 +1,252 @@ +import gzip +import random + +import numpy as np +import torch +import torch.optim as optim +import tqdm +from packaging import version +from palm_pytorch import PaLM +from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper +from torch.nn import functional as F +from torch.utils.data import DataLoader, Dataset + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.parallel import GeminiDDP, ZeroDDP +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.utils import MultiTimer, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext + +# constants + +NUM_BATCHES = int(1000) +GRADIENT_ACCUMULATE_EVERY = 1 +LEARNING_RATE = 2e-4 +VALIDATE_EVERY = 100 +GENERATE_EVERY = 500 +GENERATE_LENGTH = 512 +SEQ_LEN = 1024 + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--distplan", + type=str, + default='colossalai', + help="The distributed plan [colossalai, pytorch].", + ) + parser.add_argument( + "--tp_degree", + type=int, + default=1, + help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--shardinit", + type=bool, + default=False, + help= + "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=8, + help="batch size per DP group of training.", + ) + args = parser.parse_args() + return args + +# helpers +def cycle(loader): + while True: + for data in loader: + yield data + + +def decode_token(token): + return str(chr(max(32, token))) + + +def decode_tokens(tokens): + return "".join(list(map(decode_token, tokens))) + + +# Gemini + ZeRO DDP +def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): + cai_version = colossalai.__version__ + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=32) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + gemini_manager = GeminiManager(placememt_policy, chunk_manager) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(placememt_policy)) + model = ZeroDDP(model, gemini_manager) + else: + raise NotImplemented(f"CAI version {cai_version} is not supported") + return model + +## Parameter Sharding Strategies for Tensor Parallelism +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) + +# Tensor Parallel +def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): + """tensor_parallelize + Sharding the Model Parameters. + Args: + model (torch.nn.Module): a torch module to be sharded + """ + for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + if hasattr(param, 'visited'): + continue + param.set_dist_spec(ReplicaSpec()) + if 'net.0' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'to_q' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'to_kv' in mn: + split_param_row_tp1d(param, pg) # row slice + elif 'to_out' in mn: + split_param_row_tp1d(param, pg) # row slice + elif '1.1' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif '1.2' in mn: + split_param_row_tp1d(param, pg) # row slice + else: + param.set_dist_spec(ReplicaSpec()) + + param.visited = True + + +args = parse_args() +if args.distplan not in ["colossalai", "pytorch"]: + raise TypeError(f"{args.distplan} is error") +disable_existing_loggers() +colossalai.launch_from_torch(config={}) + +with gzip.open("./data/enwik8.gz") as file: + X = np.fromstring(file.read(int(95e6)), dtype=np.uint8) + trX, vaX = np.split(X, [int(90e6)]) + data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX) + + +class TextSamplerDataset(Dataset): + + def __init__(self, data, seq_len): + super().__init__() + self.data = data + self.seq_len = seq_len + + def __getitem__(self, index): + rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) + full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long() + return full_seq.cuda() + + def __len__(self): + return self.data.size(0) // self.seq_len + + +train_dataset = TextSamplerDataset(data_train, SEQ_LEN) +val_dataset = TextSamplerDataset(data_val, SEQ_LEN) +train_loader = cycle(DataLoader(train_dataset, batch_size=args.batch_size)) +val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size)) + +if args.distplan == "colossalai": + # instantiate GPT-like decoder model + + default_pg = ProcessGroup(tp_degree=args.tp_degree) + default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None + ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg) + + with ctx: + model = PaLM(num_tokens=256, dim=512, depth=8) + model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN) + + pg = default_pg + tensor_parallelize(model, pg) + model = gemini_zero_dpp(model, pg, args.placement) + + #optimizer + + #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) + optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5) +else: + model = PaLM(num_tokens=256, dim=512, depth=8) + model = AutoregressiveWrapper(model, max_seq_len=2048) + model.cuda() + optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) + + + +# training +model.train() + +for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"): + + if args.distplan == "colossalai": + optimizer.zero_grad() + + loss = model(next(train_loader)) + # loss.backward() + optimizer.backward(loss) + + print(f"training loss: {loss.item()}") + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + # optim.step() + # optim.zero_grad() + optimizer.step() + else: + for __ in range(GRADIENT_ACCUMULATE_EVERY): + loss = model(next(train_loader)) + loss.backward() + + print(f"training loss: {loss.item()}") + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optim.step() + optim.zero_grad() + + # TODO + # if i % VALIDATE_EVERY == 0: + # model.eval() + # with torch.no_grad(): + # loss = model(next(val_loader)) + # print(f"validation loss: {loss.item()}") + + # if i % GENERATE_EVERY == 0: + # model.eval() + # inp = random.choice(val_dataset)[:-1] + # prime = decode_tokens(inp) + # print(f"%s \n\n %s", (prime, "*" * 100)) + + # sample = model.generate(inp[None, ...], GENERATE_LENGTH) + # output_str = decode_tokens(sample[0]) + # print(output_str) \ No newline at end of file diff --git a/examples/language/roberta/README.md b/examples/language/roberta/README.md new file mode 100644 index 000000000..a42b1935d --- /dev/null +++ b/examples/language/roberta/README.md @@ -0,0 +1,58 @@ +# Introduction +This repo introduce how to pretrain a chinese roberta-large from scratch, including preprocessing, pretraining, finetune. The repo can help you quickly train a high-quality bert. + +## 0. Prerequisite +- Install Colossal-AI +- Editing the port from /etc/ssh/sshd_config and /etc/ssh/ssh_config, every host expose the same ssh port of server and client. If you are a root user, you also set the **PermitRootLogin** from /etc/ssh/sshd_config to "yes" +- Ensure that each host can log in to each other without password. If you have n hosts, need to execute n2 times + +``` +ssh-keygen +ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination +``` + +- In all hosts, edit /etc/hosts to record all hosts' name and ip.The example is shown below. + +```bash +192.168.2.1 GPU001 +192.168.2.2 GPU002 +192.168.2.3 GPU003 +192.168.2.4 GPU004 +192.168.2.5 GPU005 +192.168.2.6 GPU006 +192.168.2.7 GPU007 +... +``` + +- restart ssh +``` +service ssh restart +``` + +## 1. Corpus Preprocessing +```bash +cd preprocessing +``` +following the `README.md`, preprocess original corpus to h5py+numpy + +## 2. Pretrain + +```bash +cd pretraining +``` +following the `README.md`, load the h5py generated by preprocess of step 1 to pretrain the model + +## 3. Finetune + +The checkpoint produced by this repo can replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) directly. Then use transfomers from Hugging Face to finetune downstream application. + +## Contributors +The repo is contributed by AI team from [Moore Threads](https://www.mthreads.com/). If you find any problems for pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. At last, welcome any form of contribution! + +``` +@misc{ + title={A simple Chinese RoBERTa Example for Whole Word Masked}, + author={Yehua Zhang, Chen Zhang}, + year={2022} +} +``` diff --git a/examples/language/roberta/configs/colossalai_ddp.py b/examples/language/roberta/configs/colossalai_ddp.py new file mode 100644 index 000000000..c3c59aa40 --- /dev/null +++ b/examples/language/roberta/configs/colossalai_ddp.py @@ -0,0 +1,4 @@ +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.nn.optimizer import FusedAdam + +clip_grad_norm = 1.0 diff --git a/examples/language/roberta/configs/colossalai_zero.py b/examples/language/roberta/configs/colossalai_zero.py new file mode 100644 index 000000000..c5debdce0 --- /dev/null +++ b/examples/language/roberta/configs/colossalai_zero.py @@ -0,0 +1,32 @@ +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.nn.optimizer import FusedAdam + +# fp16 = dict( +# mode=AMP_TYPE.TORCH, +# ) + +# seed = 2 +zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), + reduce_scatter_bucket_size_mb=25, + fp32_reduce_scatter=False, + tensor_placement_policy="cuda", + gradient_predivide_factor=1.0, + reuse_fp16_shard=False), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, + initial_scale=2**5, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2, + max_scale=2**32)) + +# gradient_accumulation = 4 +clip_grad_norm = 1.0 +optimizer = dict( + type=FusedAdam, + lr=0.00015, + weight_decay=1e-2, +) + +# 64433 \ No newline at end of file diff --git a/examples/language/roberta/preprocessing/Makefile b/examples/language/roberta/preprocessing/Makefile new file mode 100644 index 000000000..82ee4e1c5 --- /dev/null +++ b/examples/language/roberta/preprocessing/Makefile @@ -0,0 +1,9 @@ +CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = mask +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/examples/language/roberta/preprocessing/README.md b/examples/language/roberta/preprocessing/README.md new file mode 100644 index 000000000..1dbd745ab --- /dev/null +++ b/examples/language/roberta/preprocessing/README.md @@ -0,0 +1,105 @@ +# Data PreProcessing for chinese Whole Word Masked + + + +## Catalogue: +* 1. Introduction +* 2. Quick Start Guide: + * 2.1. Split Sentence + * 2.2.Tokenizer & Whole Word Masked + + + + +## 1. Introduction: [Back to Top] +This folder is used to preprocess chinese corpus with Whole Word Masked. You can obtain corpus from [WuDao](https://resource.wudaoai.cn/home?ind&name=WuDaoCorpora%202.0&id=1394901288847716352). Moreover, data preprocessing is flexible, and you can modify the code based on your needs, hardware or parallel framework(Open MPI, Spark, Dask). + + + +## 2. Quick Start Guide: [Back to Top] + + + +### 2.1. Split Sentence & Split data into multiple shard: +Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentence through punctuation, such as `。!`. **Secondly, split data into multiple shard based on server hardware (cpu, cpu memory, hard disk) and corpus size.** Each shard contains a part of corpus, and the model needs to train all the shards as one epoch. +In this example, split 200G Corpus into 100 shard, and each shard is about 2G. The size of the shard is memory-dependent, taking into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training to read the shard (n data parallel requires n\*shard_size memory). **To sum up, data preprocessing and model pretraining requires fighting with hardware, not just GPU.** + +```python +python sentence_split.py --input_path /orginal_corpus --output_path /shard --shard 100 +# This step takes a short time +``` +* `--input_path`: all original corpus, e.g., /orginal_corpus/0.json /orginal_corpus/1.json ... +* `--output_path`: all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ... +* `--shard`: Number of shard, e.g., 10, 50, or 100 + +Input json: + +``` +[ + { + "id": 0, + "title": "打篮球", + "content": "我今天去打篮球。不回来吃饭。" + } + { + "id": 1, + "title": "旅游", + "content": "我后天去旅游。下周请假。" + } +] +``` + +Output txt: + +``` +我今天去打篮球。 +不回来吃饭。 +]] +我后天去旅游。 +下周请假。 +``` + + + +### 2.2. Tokenizer & Whole Word Masked: + +```python +python tokenize_mask.py --input_path /shard --output_path /h5 --tokenizer_path /roberta --backend python +# This step is time consuming and is mainly spent on mask +``` + +**[optional but recommended]**: the C++ backend with `pybind11` can provide faster speed + +```shell +make +``` + +* `--input_path`: location of all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ... +* `--output_path`: location of all h5 with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ... +* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenzier.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) +* `--backend`: python or c++, **specifies c++ can obtain faster preprocess speed** +* `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document +* `--worker`: number of process + +Input txt: + +``` +我今天去打篮球。 +不回来吃饭。 +]] +我后天去旅游。 +下周请假。 +``` + +Output h5+numpy: + +``` +'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..], + ...] +'input_mask': [[1,1,1,1,1,1,0,0..], + ...] +'segment_ids': [[0,0,0,0,0,...], + ...] +'masked_lm_positions': [[label1,-1,-1,label2,-1...], + ...] +``` \ No newline at end of file diff --git a/examples/language/roberta/preprocessing/get_mask.py b/examples/language/roberta/preprocessing/get_mask.py new file mode 100644 index 000000000..da297f98e --- /dev/null +++ b/examples/language/roberta/preprocessing/get_mask.py @@ -0,0 +1,266 @@ +import torch +import os +from enum import IntEnum +from random import choice +import random +import collections +import time +import logging +import jieba +jieba.setLogLevel(logging.CRITICAL) +import re +import numpy as np +import mask + +PAD = 0 +MaskedLMInstance = collections.namedtuple("MaskedLMInstance", + ["index", "label"]) + + +def map_to_numpy(data): + return np.asarray(data) + + +class PreTrainingDataset(): + def __init__(self, + tokenizer, + max_seq_length, + backend='python', + max_predictions_per_seq: int = 80, + do_whole_word_mask: bool = True): + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.masked_lm_prob = 0.15 + self.backend = backend + self.do_whole_word_mask = do_whole_word_mask + self.max_predictions_per_seq = max_predictions_per_seq + self.vocab_words = list(tokenizer.vocab.keys()) + self.rec = re.compile('[\u4E00-\u9FA5]') + self.whole_rec = re.compile('##[\u4E00-\u9FA5]') + + self.mlm_p = 0.15 + self.mlm_mask_p = 0.8 + self.mlm_tamper_p = 0.05 + self.mlm_maintain_p = 0.1 + + + def tokenize(self, doc): + temp = [] + for d in doc: + temp.append(self.tokenizer.tokenize(d)) + return temp + + + def create_training_instance(self, instance): + is_next = 1 + raw_text_list = self.get_new_segment(instance) + tokens_a = raw_text_list + assert len(tokens_a) == len(instance) + # tokens_a, tokens_b, is_next = instance.get_values() + # print(f'is_next label:{is_next}') + # Create mapper + tokens = [] + original_tokens = [] + segment_ids = [] + tokens.append("[CLS]") + original_tokens.append('[CLS]') + segment_ids.append(0) + for index, token in enumerate(tokens_a): + tokens.append(token) + original_tokens.append(instance[index]) + segment_ids.append(0) + + tokens.append("[SEP]") + original_tokens.append('[SEP]') + segment_ids.append(0) + + # for token in tokens_b: + # tokens.append(token) + # segment_ids.append(1) + + # tokens.append("[SEP]") + # segment_ids.append(1) + + # Get Masked LM predictions + if self.backend == 'c++': + output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(tokens, original_tokens, self.vocab_words, + self.tokenizer.vocab, self.max_predictions_per_seq, self.masked_lm_prob) + elif self.backend == 'python': + output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens) + + # Convert to Ids + input_ids = self.tokenizer.convert_tokens_to_ids(output_tokens) + input_mask = [1] * len(input_ids) + + while len(input_ids) < self.max_seq_length: + input_ids.append(PAD) + segment_ids.append(PAD) + input_mask.append(PAD) + masked_lm_output.append(-1) + return ([ + map_to_numpy(input_ids), + map_to_numpy(input_mask), + map_to_numpy(segment_ids), + map_to_numpy(masked_lm_output), + map_to_numpy([is_next]) + ]) + + + def create_masked_lm_predictions(self, tokens): + cand_indexes = [] + for i, token in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and + token.startswith("##")): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + + # cand_indexes.append(i) + + random.shuffle(cand_indexes) + output_tokens = list(tokens) + + num_to_predict = min( + self.max_predictions_per_seq, + max(1, int(round(len(tokens) * self.masked_lm_prob)))) + + masked_lms = [] + covered_indexes = set() + for index in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + if index in covered_indexes: + continue + covered_indexes.add(index) + + masked_token = None + # 80% mask + if random.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% Keep Original + if random.random() < 0.5: + masked_token = tokens[index] + # 10% replace w/ random word + else: + masked_token = self.vocab_words[random.randint( + 0, + len(self.vocab_words) - 1)] + + output_tokens[index] = masked_token + masked_lms.append( + MaskedLMInstance(index=index, label=tokens[index])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + masked_lm_output = [-1] * len(output_tokens) + for p in masked_lms: + masked_lm_output[p.index] = self.tokenizer.vocab[p.label] + + return (output_tokens, masked_lm_output) + + + def get_new_segment(self, segment): + """ + 输入一句话,返回一句经过处理的话: 为了支持中文全称mask,将被分开的词,将上特殊标记("#"),使得后续处理模块,能够知道哪些字是属于同一个词的。 + :param segment: 一句话 + :return: 一句处理过的话 + """ + seq_cws = jieba.lcut(''.join(segment)) + seq_cws_dict = {x: 1 for x in seq_cws} + new_segment = [] + i = 0 + while i < len(segment): + if len(self.rec.findall(segment[i])) == 0: # 不是中文的,原文加进去。 + new_segment.append(segment[i]) + i += 1 + continue + + has_add = False + for length in range(3, 0, -1): + if i + length > len(segment): + continue + if ''.join(segment[i: i+length]) in seq_cws_dict: + new_segment.append(segment[i]) + for l in range(1, length): + new_segment.append('##' + segment[i+l]) + i += length + has_add = True + break + if not has_add: + new_segment.append(segment[i]) + i += 1 + return new_segment + + + def create_whole_masked_lm_predictions(self, tokens): + """Creates the predictions for the masked LM objective.""" + + cand_indexes = [] + for (i, token) in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and + token.startswith("##")): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + + random.shuffle(cand_indexes) + + output_tokens = [t[2:] if len(self.whole_rec.findall(t))>0 else t for t in tokens] # 去掉"##" + + num_to_predict = min(self.max_predictions_per_seq, + max(1, int(round(len(tokens) * self.masked_lm_prob)))) + + masked_lms = [] + covered_indexes = set() + for index_set in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if random.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% of the time, keep original + if random.random() < 0.5: + masked_token = tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index] # 去掉"##" + # 10% of the time, replace with random word + else: + masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] + + output_tokens[index] = masked_token + + masked_lms.append(MaskedLMInstance(index=index, label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index])) + assert len(masked_lms) <= num_to_predict + masked_lms = sorted(masked_lms, key=lambda x: x.index) + masked_lm_output = [-1] * len(output_tokens) + for p in masked_lms: + masked_lm_output[p.index] = self.tokenizer.vocab[p.label] + + return (output_tokens, masked_lm_output) diff --git a/examples/language/roberta/preprocessing/mask.cpp b/examples/language/roberta/preprocessing/mask.cpp new file mode 100644 index 000000000..8355c45cf --- /dev/null +++ b/examples/language/roberta/preprocessing/mask.cpp @@ -0,0 +1,184 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +const int32_t LONG_SENTENCE_LEN = 512; + +struct MaskedLMInstance { + int index; + std::string label; + MaskedLMInstance(int index, std::string label) { + this->index = index; + this->label = label; + } +}; + +auto get_new_segment(std::vector segment, std::vector segment_jieba, const std::vector chinese_vocab) { // const std::unordered_set &chinese_vocab + std::unordered_set seq_cws_dict; + for (auto word : segment_jieba) { + seq_cws_dict.insert(word); + } + int i = 0; + std::vector new_segment; + int segment_size = segment.size(); + while (i < segment_size) { + if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end() + new_segment.emplace_back(segment[i]); + i += 1; + continue; + } + bool has_add = false; + for (int length = 3; length >= 1; length--) { + if (i + length > segment_size) { + continue; + } + std::string chinese_word = ""; + for (int j = i; j < i + length; j++) { + chinese_word += segment[j]; + } + if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) { + new_segment.emplace_back(segment[i]); + for (int j = i + 1; j < i + length; j++) { + new_segment.emplace_back("##" + segment[j]); + } + i += length; + has_add = true; + break; + } + } + if (!has_add) { + new_segment.emplace_back(segment[i]); + i += 1; + } + } + + return new_segment; +} + +bool startsWith(const std::string& s, const std::string& sub) { + return s.find(sub) == 0 ? true : false; +} + +auto create_whole_masked_lm_predictions(std::vector &tokens, + const std::vector &original_tokens, + const std::vector &vocab_words, + std::map &vocab, + const int max_predictions_per_seq, + const double masked_lm_prob) { + // for (auto item : vocab) { + // std::cout << "key=" << std::string(py::str(item.first)) << ", " + // << "value=" << std::string(py::str(item.second)) << std::endl; + // } + std::vector > cand_indexes; + std::vector cand_temp; + int tokens_size = tokens.size(); + std::string prefix = "##"; + bool do_whole_masked = true; + + for (int i = 0; i < tokens_size; i++) { + if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") { + continue; + } + if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) { + cand_temp.emplace_back(i); + } + else { + if (cand_temp.size() > 0) { + cand_indexes.emplace_back(cand_temp); + } + cand_temp.clear(); + cand_temp.emplace_back(i); + } + } + auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed)); + // for (auto i : cand_indexes) { + // for (auto j : i) { + // std::cout << tokens[j] << " "; + // } + // std::cout << std::endl; + // } + // for (auto i : output_tokens) { + // std::cout << i; + // } + // std::cout << std::endl; + + int num_to_predict = std::min(max_predictions_per_seq, + std::max(1, int(tokens_size * masked_lm_prob))); + // std::cout << num_to_predict << std::endl; + + std::set covered_indexes; + std::vector masked_lm_output(tokens_size, -1); + int vocab_words_len = vocab_words.size(); + std::default_random_engine e(seed); + std::uniform_real_distribution u1(0.0, 1.0); + std::uniform_int_distribution u2(0, vocab_words_len - 1); + int mask_cnt = 0; + std::vector output_tokens; + output_tokens = original_tokens; + + for (auto index_set : cand_indexes) { + if (mask_cnt > num_to_predict) { + break; + } + int index_set_size = index_set.size(); + if (mask_cnt + index_set_size > num_to_predict) { + continue; + } + bool is_any_index_covered = false; + for (auto index : index_set) { + if (covered_indexes.find(index) != covered_indexes.end()) { + is_any_index_covered = true; + break; + } + } + if (is_any_index_covered) { + continue; + } + for (auto index : index_set) { + + covered_indexes.insert(index); + std::string masked_token; + if (u1(e) < 0.8) { + masked_token = "[MASK]"; + } + else { + if (u1(e) < 0.5) { + masked_token = output_tokens[index]; + } + else { + int random_index = u2(e); + masked_token = vocab_words[random_index]; + } + } + // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index])); + masked_lm_output[index] = vocab[output_tokens[index]]; + output_tokens[index] = masked_token; + mask_cnt++; + } + } + + // for (auto p : masked_lms) { + // masked_lm_output[p.index] = vocab[p.label]; + // } + return std::make_tuple(output_tokens, masked_lm_output); +} + +PYBIND11_MODULE(mask, m) { + m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions); + m.def("get_new_segment", &get_new_segment); +} diff --git a/examples/language/roberta/preprocessing/sentence_split.py b/examples/language/roberta/preprocessing/sentence_split.py new file mode 100644 index 000000000..231be152b --- /dev/null +++ b/examples/language/roberta/preprocessing/sentence_split.py @@ -0,0 +1,163 @@ + +import multiprocessing +import os +import re +from tqdm import tqdm +from typing import List +import json +import time +import argparse +import functools + +def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]: + """ + Args: + document: + flag: Type:str, "all" 中英文标点分句,"zh" 中文标点分句,"en" 英文标点分句 + limit: 默认单句最大长度为510个字符 + Returns: Type:list + """ + sent_list = [] + try: + if flag == "zh": + document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) # 单字符断句符 + document = re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) # 特殊引号 + elif flag == "en": + document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) # 英文单字符断句符 + document = re.sub('(?P([?!.]["\']))', r'\g\n', document) # 特殊引号 + else: + document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', document) # 单字符断句符 + + document = re.sub('(?P(([。?!.!?]|…{1,2})[”’"\']))', r'\g\n', + document) # 特殊引号 + + sent_list_ori = document.splitlines() + for sent in sent_list_ori: + sent = sent.strip() + if not sent: + continue + elif len(sent) <= 2: + continue + else: + while len(sent) > limit: + temp = sent[0:limit] + sent_list.append(temp) + sent = sent[limit:] + sent_list.append(sent) + except: + sent_list.clear() + sent_list.append(document) + return sent_list + + +def get_sent(output_path, + input_path, + fin_list=[], host=-1, seq_len=512) -> None: + + workers = 32 + + if input_path[-1] == '/': + input_path = input_path[:-1] + + cur_path = os.path.join(output_path, str(host) + '.txt') + new_split_sentence = functools.partial(split_sentence, limit=seq_len-2) + with open(cur_path, 'w', encoding='utf-8') as f: + for fi, fin_path in enumerate(fin_list): + if not os.path.exists(os.path.join(input_path, fin_path[0])): + continue + if '.json' not in fin_path[0]: + continue + + print("Processing ", fin_path[0], " ", fi) + + with open(os.path.join(input_path, fin_path[0]), 'r') as fin: + f_data = [l['content'] for l in json.load(fin)] + + pool = multiprocessing.Pool(workers) + all_sent = pool.imap_unordered(new_split_sentence, f_data, 32) + pool.close() + print('finished..') + + cnt = 0 + for d in tqdm(all_sent): + for i in d: + f.write(i.strip() + '\n') + f.write(']]' + '\n') + cnt += 1 + # if cnt >= 2: + # exit() + + +def getFileSize(filepath, shard): + all_data = [] + for i in os.listdir(filepath): + all_data.append(os.path.join(filepath, i)) + all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data]) + ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] + ans = sorted(ans, key=lambda x: x[1], reverse=True) + per_size = all_size / shard + real_shard = [] + temp = [] + accu_size = 0 + for i in ans: + accu_size += i[1] + temp.append(i) + if accu_size > per_size: + real_shard.append(temp) + accu_size = 0 + temp = [] + + if len(temp) > 0: + real_shard.append(temp) + + return real_shard + + +def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): + import socket + host = int(socket.gethostname().split(server_name)[-1]) + + fin_list = real_shard[server_num * base + host - 1] + print(fin_list) + print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}') + return fin_list, host + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--server_num', type=int, default=10, help='number of servers') + parser.add_argument('--seq_len', type=int, default=512, help='sequence length') + parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100') + parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus') + parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence') + args = parser.parse_args() + + server_num = args.server_num + seq_len = args.seq_len + shard = args.shard + input_path = args.input_path + output_path = args.output_path + + real_shard = getFileSize(input_path, shard) + + start = time.time() + for index, shard in enumerate(real_shard): + get_sent(output_path, + input_path, + fin_list=shard, + host=index, + seq_len=seq_len) + print(f'cost {str(time.time() - start)}') + + # if you have multiple server, you can use code below or modify code to openmpi + + # for i in range(len(real_shard) // server_num + 1): + # fin_list, host = get_start_end(real_shard, i) + + # start = time.time() + # get_sent(output_path, + # input_path, + # fin_list=fin_list, host= 10 * i + host - 1) + + # print(f'cost {str(time.time() - start)}') diff --git a/examples/language/roberta/preprocessing/tokenize_mask.py b/examples/language/roberta/preprocessing/tokenize_mask.py new file mode 100644 index 000000000..b33871d5d --- /dev/null +++ b/examples/language/roberta/preprocessing/tokenize_mask.py @@ -0,0 +1,275 @@ +import time +import os +import psutil +import h5py +import socket +import argparse +import numpy as np +import multiprocessing +from tqdm import tqdm +from random import shuffle +from transformers import AutoTokenizer +from get_mask import PreTrainingDataset + + +def get_raw_instance(document, max_sequence_length=512): + + """ + 获取初步的训练实例,将整段按照max_sequence_length切分成多个部分,并以多个处理好的实例的形式返回。 + :param document: 一整段 + :param max_sequence_length: + :return: a list. each element is a sequence of text + """ + # document = self.documents[index] + max_sequence_length_allowed = max_sequence_length - 2 + # document = [seq for seq in document if len(seq)= max_sequence_length_allowed: + if len(curr_seq) > 0: + result_list.append(curr_seq) + curr_seq = [] + result_list.append(document[sz_idx][ : max_sequence_length_allowed]) + sz_idx += 1 + else: + result_list.append(curr_seq) + curr_seq = [] + # 对最后一个序列进行处理,如果太短的话,丢弃掉。 + if len(curr_seq) > max_sequence_length_allowed / 2: # /2 + result_list.append(curr_seq) + + # # 计算总共可以得到多少份 + # num_instance=int(len(big_list)/max_sequence_length_allowed)+1 + # print("num_instance:",num_instance) + # # 切分成多份,添加到列表中 + # result_list=[] + # for j in range(num_instance): + # index=j*max_sequence_length_allowed + # end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1 + # result_list.append(big_list[index:end_index]) + return result_list + + +def split_numpy_chunk(path, tokenizer, pretrain_data, host): + + documents = [] + instances = [] + + s = time.time() + with open(path, encoding='utf-8') as fd: + document = [] + for i, line in enumerate(tqdm(fd)): + line = line.strip() + # document = line + # if len(document.split("")) <= 3: + # continue + if len(line + ) > 0 and line[:2] == "]]": # This is end of document + documents.append(document) + document = [] + elif len(line) >= 2: + document.append(line) + if len(document) > 0: + documents.append(document) + print('read_file ', time.time() - s) + + # documents = [x for x in documents if x] + # print(len(documents)) + # print(len(documents[0])) + # print(documents[0][0:10]) + from typing import List + import multiprocessing + + ans = [] + for docs in tqdm(documents): + ans.append(pretrain_data.tokenize(docs)) + print(time.time() - s) + del documents + + instances = [] + for a in tqdm(ans): + raw_ins = get_raw_instance(a) + instances.extend(raw_ins) + del ans + + print('len instance', len(instances)) + + sen_num = len(instances) + seq_len = 512 + input_ids = np.zeros([sen_num, seq_len], dtype=np.int32) + input_mask = np.zeros([sen_num, seq_len], dtype=np.int32) + segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32) + masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32) + + for index, ins in tqdm(enumerate(instances)): + mask_dict = pretrain_data.create_training_instance(ins) + input_ids[index] = mask_dict[0] + input_mask[index] = mask_dict[1] + segment_ids[index] = mask_dict[2] + masked_lm_output[index] = mask_dict[3] + + with h5py.File(f'/output/{host}.h5', 'w') as hf: + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_ids) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) + + del instances + + +def split_numpy_chunk_pool(input_path, + output_path, + pretrain_data, + worker, + dupe_factor, + seq_len, + file_name): + + if os.path.exists(os.path.join(output_path, f'{file_name}.h5')): + print(f'{file_name}.h5 exists') + return + + documents = [] + instances = [] + + s = time.time() + with open(input_path, 'r', encoding='utf-8') as fd: + document = [] + for i, line in enumerate(tqdm(fd)): + line = line.strip() + if len(line + ) > 0 and line[:2] == "]]": # This is end of document + documents.append(document) + document = [] + elif len(line) >= 2: + document.append(line) + if len(document) > 0: + documents.append(document) + print(f'read_file cost {time.time() - s}, length is {len(documents)}') + + ans = [] + s = time.time() + pool = multiprocessing.Pool(worker) + encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100) + for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'): + ans.append(res) + pool.close() + print((time.time() - s) / 60) + del documents + + instances = [] + for a in tqdm(ans, colour='MAGENTA'): + raw_ins = get_raw_instance(a, max_sequence_length=seq_len) + instances.extend(raw_ins) + del ans + + print('len instance', len(instances)) + + new_instances = [] + for _ in range(dupe_factor): + for ins in instances: + new_instances.append(ins) + + shuffle(new_instances) + instances = new_instances + print('after dupe_factor, len instance', len(instances)) + + sentence_num = len(instances) + input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) + input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32) + segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) + masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32) + + s = time.time() + pool = multiprocessing.Pool(worker) + encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32) + for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'): + input_ids[index] = mask_dict[0] + input_mask[index] = mask_dict[1] + segment_ids[index] = mask_dict[2] + masked_lm_output[index] = mask_dict[3] + pool.close() + print((time.time() - s) / 60) + + with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf: + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_mask) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) + + del instances + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer') + parser.add_argument('--seq_len', type=int, default=512, help='sequence length') + parser.add_argument('--max_predictions_per_seq', type=int, default=80, help='number of shards, e.g., 10, 50, or 100') + parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence') + parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id') + parser.add_argument('--backend', type=str, default='python', help='backend of mask token, python, c++, numpy respectively') + parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document') + parser.add_argument('--worker', type=int, default=32, help='number of process') + parser.add_argument('--server_num', type=int, default=10, help='number of servers') + args = parser.parse_args() + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + pretrain_data = PreTrainingDataset(tokenizer, + args.seq_len, + args.backend, + max_predictions_per_seq=args.max_predictions_per_seq) + + + data_len = len(os.listdir(args.input_path)) + + for i in range(data_len): + input_path = os.path.join(args.input_path, f'{i}.txt') + if os.path.exists(input_path): + start = time.time() + print(f'process {input_path}') + split_numpy_chunk_pool(input_path, + args.output_path, + pretrain_data, + args.worker, + args.dupe_factor, + args.seq_len, + i) + end_ = time.time() + print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) ) + print(f'has cost {(end_ - start) / 60}') + print('-' * 100) + print('') + + # if you have multiple server, you can use code below or modify code to openmpi + + # host = int(socket.gethostname().split('GPU')[-1]) + # for i in range(data_len // args.server_num + 1): + # h = args.server_num * i + host - 1 + # input_path = os.path.join(args.input_path, f'{h}.txt') + # if os.path.exists(input_path): + # start = time.time() + # print(f'I am server {host}, process {input_path}') + # split_numpy_chunk_pool(input_path, + # args.output_path, + # pretrain_data, + # args.worker, + # args.dupe_factor, + # args.seq_len, + # h) + # end_ = time.time() + # print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) ) + # print(f'has cost {(end_ - start) / 60}') + # print('-' * 100) + # print('') + + diff --git a/examples/language/roberta/pretraining/README.md b/examples/language/roberta/pretraining/README.md new file mode 100644 index 000000000..055d69696 --- /dev/null +++ b/examples/language/roberta/pretraining/README.md @@ -0,0 +1,24 @@ +# Pretraining +1. Pretraining roberta through running the script below. Detailed parameter descriptions can be found in the arguments.py. `data_path_prefix` is absolute path specifies output of preprocessing. **You have to modify the *hostfile* according to your cluster.** + +```bash +bash run_pretrain.sh +``` +* `--hostfile`: servers' host name from /etc/hosts +* `--include`: servers which will be used +* `--nproc_per_node`: number of process(GPU) from each server +* `--data_path_prefix`: absolute location of train data, e.g., /h5/0.h5 +* `--eval_data_path_prefix`: absolute location of eval data +* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json, e.g./tokenizer/tokenizer.json +* `--bert_config`: config.json which represent model +* `--mlm`: model type of backbone, bert or deberta_v2 + +2. if resume training from earylier checkpoint, run the script below. + +```shell +bash run_pretrain_resume.sh +``` +* `--resume_train`: whether to resume training +* `--load_pretrain_model`: absolute path which contains model checkpoint +* `--load_optimizer_lr`: absolute path which contains optimizer checkpoint + diff --git a/examples/language/roberta/pretraining/arguments.py b/examples/language/roberta/pretraining/arguments.py new file mode 100644 index 000000000..3a9370e00 --- /dev/null +++ b/examples/language/roberta/pretraining/arguments.py @@ -0,0 +1,152 @@ +import colossalai +from numpy import require + +__all__ = ['parse_args'] + + +def parse_args(): + parser = colossalai.get_default_parser() + + parser.add_argument( + '--lr', + type=float, + required=True, + help='initial learning rate') + parser.add_argument( + '--epoch', + type=int, + required=True, + help='number of epoch') + parser.add_argument( + '--data_path_prefix', + type=str, + required=True, + help="location of the train data corpus") + parser.add_argument( + '--eval_data_path_prefix', + type=str, + required=True, + help='location of the evaluation data corpus') + parser.add_argument( + '--tokenizer_path', + type=str, + required=True, + help='location of the tokenizer') + parser.add_argument( + '--max_seq_length', + type=int, + default=512, + help='sequence length') + parser.add_argument( + '--refresh_bucket_size', + type=int, + default=1, + help= + "This param makes sure that a certain task is repeated for this time steps to \ + optimise on the back propogation speed with APEX's DistributedDataParallel") + parser.add_argument( + "--max_predictions_per_seq", + "--max_pred", + default=80, + type=int, + help= + "The maximum number of masked tokens in a sequence to be predicted.") + parser.add_argument( + "--gradient_accumulation_steps", + default=1, + type=int, + help="accumulation_steps") + parser.add_argument( + "--train_micro_batch_size_per_gpu", + default=2, + type=int, + required=True, + help="train batch size") + parser.add_argument( + "--eval_micro_batch_size_per_gpu", + default=2, + type=int, + required=True, + help="eval batch size") + parser.add_argument( + "--num_workers", + default=8, + type=int, + help="") + parser.add_argument( + "--async_worker", + action='store_true', + help="") + parser.add_argument( + "--bert_config", + required=True, + type=str, + help="location of config.json") + parser.add_argument( + "--wandb", + action='store_true', + help="use wandb to watch model") + parser.add_argument( + "--wandb_project_name", + default='roberta', + help="wandb project name") + parser.add_argument( + "--log_interval", + default=100, + type=int, + help="report interval") + parser.add_argument( + "--log_path", + type=str, + required=True, + help="log file which records train step") + parser.add_argument( + "--tensorboard_path", + type=str, + required=True, + help="location of tensorboard file") + parser.add_argument( + "--colossal_config", + type=str, + required=True, + help="colossal config, which contains zero config and so on") + parser.add_argument( + "--ckpt_path", + type=str, + required=True, + help="location of saving checkpoint, which contains model and optimizer") + parser.add_argument( + '--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument( + '--vscode_debug', + action='store_true', + help="use vscode to debug") + parser.add_argument( + '--load_pretrain_model', + default='', + type=str, + help="location of model's checkpoin") + parser.add_argument( + '--load_optimizer_lr', + default='', + type=str, + help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step") + parser.add_argument( + '--resume_train', + action='store_true', + help="whether resume training from a early checkpoint") + parser.add_argument( + '--mlm', + default='bert', + type=str, + help="model type, bert or deberta") + parser.add_argument( + '--checkpoint_activations', + action='store_true', + help="whether to use gradient checkpointing") + + args = parser.parse_args() + return args diff --git a/examples/language/roberta/pretraining/bert_dataset_provider.py b/examples/language/roberta/pretraining/bert_dataset_provider.py new file mode 100644 index 000000000..1d8cf2a91 --- /dev/null +++ b/examples/language/roberta/pretraining/bert_dataset_provider.py @@ -0,0 +1,15 @@ +class BertDatasetProviderInterface: + def get_shard(self, index, shuffle=True): + raise NotImplementedError + + def release_shard(self, index): + raise NotImplementedError + + def prefetch_shard(self, index): + raise NotImplementedError + + def get_batch(self, batch_iter): + raise NotImplementedError + + def prefetch_batch(self): + raise NotImplementedError diff --git a/examples/language/roberta/pretraining/evaluation.py b/examples/language/roberta/pretraining/evaluation.py new file mode 100644 index 000000000..83f94082f --- /dev/null +++ b/examples/language/roberta/pretraining/evaluation.py @@ -0,0 +1,71 @@ +import os +import math +import torch +from tqdm import tqdm +from utils.global_vars import get_timers, get_tensorboard_writer +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider + +def evaluate(engine, args, logger, global_step): + evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True) + start_shard = 0 + + engine.eval() + timers = get_timers() + eval_step = 0 + eval_loss = 0 + cur_loss = 0 + world_size = torch.distributed.get_world_size() + + with torch.no_grad(): + + for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))): + + timers('eval_shard_time').start() + + dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard) + # evaluate_dataset_provider.prefetch_shard(shard + 1) + if torch.distributed.get_rank() == 0: + iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), colour='MAGENTA', smoothing=1) + else: + iterator_data = enumerate(dataset_iterator) + + for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): + + # batch_data = pretrain_dataset_provider.get_batch(batch_index) + eval_step += 1 + input_ids = batch_data[0].cuda() + attention_mask = batch_data[1].cuda() + token_type_ids = batch_data[2].cuda() + mlm_label = batch_data[3].cuda() + # nsp_label = batch_data[5].cuda() + + output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + loss = engine.criterion(output.logits, mlm_label)#prediction_scores + evaluate_dataset_provider.prefetch_batch() + + eval_loss += loss.float().item() + + cur_loss = eval_loss / eval_step + elapsed_time = timers("eval_shard_time").elapsed() + elapsed_time_per_iteration = elapsed_time / eval_step + ppl = math.exp(cur_loss) + + if args.wandb and torch.distributed.get_rank() == 0: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_eval({ + 'loss': cur_loss, + 'ppl': ppl, + 'mins_batch': elapsed_time_per_iteration + }, global_step) + + eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ + f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}' + + logger.info(eval_log_str) + logger.info('-' * 100) + logger.info('') + + evaluate_dataset_provider.release_shard() + engine.train() + return cur_loss diff --git a/examples/language/roberta/pretraining/hostfile b/examples/language/roberta/pretraining/hostfile new file mode 100644 index 000000000..f4e047f01 --- /dev/null +++ b/examples/language/roberta/pretraining/hostfile @@ -0,0 +1,10 @@ +GPU001 +GPU002 +GPU003 +GPU004 +GPU005 +GPU006 +GPU007 +GPU008 +GPU009 +GPU010 diff --git a/examples/language/roberta/pretraining/loss.py b/examples/language/roberta/pretraining/loss.py new file mode 100644 index 000000000..dc4f872a7 --- /dev/null +++ b/examples/language/roberta/pretraining/loss.py @@ -0,0 +1,17 @@ +import torch + +__all__ = ['LossForPretraining'] + + +class LossForPretraining(torch.nn.Module): + + def __init__(self, vocab_size): + super(LossForPretraining, self).__init__() + self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1) + self.vocab_size = vocab_size + + def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None): + masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1)) + # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1)) + total_loss = masked_lm_loss #+ next_sentence_loss + return total_loss diff --git a/examples/language/roberta/pretraining/model/bert.py b/examples/language/roberta/pretraining/model/bert.py new file mode 100644 index 000000000..67c85f760 --- /dev/null +++ b/examples/language/roberta/pretraining/model/bert.py @@ -0,0 +1,1893 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + + +import math +import os +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "bert-base-uncased" +_CONFIG_FOR_DOC = "BertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + +# TokenClassification docstring +_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" +_TOKEN_CLASS_EXPECTED_OUTPUT = ( + "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " +) +_TOKEN_CLASS_EXPECTED_LOSS = 0.01 + +# QuestionAnswering docstring +_CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2" +_QA_EXPECTED_OUTPUT = "'a nice puppet'" +_QA_EXPECTED_LOSS = 7.41 +_QA_TARGET_START_INDEX = 14 +_QA_TARGET_END_INDEX = 15 + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity" +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" +_SEQ_CLASS_EXPECTED_LOSS = 0.01 + + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bert-base-uncased", + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse("1.6.0"): + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = BertAttention(config, position_embedding_type="absolute") + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + + +@dataclass +class BertForPreTrainingOutput(ModelOutput): + """ + Output type of [`BertForPreTraining`]. + + Args: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +BERT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. + """, + BERT_START_DOCSTRING, +) +class BertForPreTraining(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + + Returns: + + Example: + + ```python + >>> from transformers import BertTokenizer, BertForPreTraining + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForPreTraining.from_pretrained("bert-base-uncased") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING +) +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'paris'", + expected_loss=0.88, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top.""", + BERT_START_DOCSTRING, +) +class BertForNextSentencePrediction(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import BertTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + BERT_START_DOCSTRING, +) +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BERT_START_DOCSTRING, +) +class BertForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BERT_START_DOCSTRING, +) +class BertForTokenClassification(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT, + expected_loss=_TOKEN_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BERT_START_DOCSTRING, +) +class BertForQuestionAnswering(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_QA, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=_QA_TARGET_START_INDEX, + qa_target_end_index=_QA_TARGET_END_INDEX, + expected_output=_QA_EXPECTED_OUTPUT, + expected_loss=_QA_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/language/roberta/pretraining/model/deberta_v2.py b/examples/language/roberta/pretraining/model/deberta_v2.py new file mode 100644 index 000000000..c6ce82847 --- /dev/null +++ b/examples/language/roberta/pretraining/model/deberta_v2.py @@ -0,0 +1,1631 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeBERTa-v2 model.""" + +import math +from collections.abc import Sequence +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import softmax_backward_data +from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config +from transformers import T5Tokenizer, T5ForConditionalGeneration, FillMaskPipeline + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DebertaV2Config" +_TOKENIZER_FOR_DOC = "DebertaV2Tokenizer" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge" + +DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/deberta-v2-xlarge", + "microsoft/deberta-v2-xxlarge", + "microsoft/deberta-v2-xlarge-mnli", + "microsoft/deberta-v2-xxlarge-mnli", +] + + +# Copied from transformers.models.deberta.modeling_deberta.ContextPooler +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (`torch.tensor`): The input tensor that will apply softmax. + mask (`torch.IntTensor`): + The mask matrix where 0 indicate that element will be ignored in the softmax calculation. + dim (int): The dimension that will apply softmax + + Example: + + ```python + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4, 20, 100]) + + >>> # Create a mask + >>> mask = (x > 0).int() + + >>> # Specify the dimension to apply softmax + >>> dim = -1 + + >>> y = XSoftmax.apply(x, mask, dim) + ```""" + + @staticmethod + def forward(self, input, mask, dim): + self.dim = dim + rmask = ~(mask.to(torch.bool)) + + output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + (output,) = self.saved_tensors + inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output) + return inputGrad, None, None + + @staticmethod + def symbolic(g, self, mask, dim): + import torch.onnx.symbolic_helper as sym_help + from torch.onnx.symbolic_opset9 import masked_fill, softmax + + mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"]) + r_mask = g.op( + "Cast", + g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), + to_i=sym_help.cast_pytorch_to_onnx["Byte"], + ) + output = masked_fill( + g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) + ) + output = softmax(g, output, dim) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) + + +# Copied from transformers.models.deberta.modeling_deberta.DropoutContext +class DropoutContext(object): + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +# Copied from transformers.models.deberta.modeling_deberta.get_mask +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool) + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.deberta.modeling_deberta.XDropout +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + +# Copied from transformers.models.deberta.modeling_deberta.StableDropout +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm +class DebertaV2SelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 +class DebertaV2Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaV2SelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if output_attentions: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 +class DebertaV2Intermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm +class DebertaV2Output(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 +class DebertaV2Layer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = DebertaV2Attention(config) + self.intermediate = DebertaV2Intermediate(config) + self.output = DebertaV2Output(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions=False, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if output_attentions: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +class DebertaV2Encoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, "position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + # rel = nn.Parameter(torch.empty((pos_ebd_size, config.hidden_size))) + # self.rel_embeddings = nn.init.normal_(rel, mean=0.0, std=config.initializer_range) + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None + self.gradient_checkpointing = False + + def get_rel_embedding(self): + att_span = self.position_buckets + rel_index = torch.arange(0, att_span * 2).long().to(self.rel_embeddings.weight.device) + rel_embeddings = self.rel_embeddings(rel_index) + # rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + # rel_embeddings = self.rel_embeddings if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + output_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + next_kv, + attention_mask, + query_states, + relative_pos, + rel_embeddings, + ) + else: + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = np.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) + log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid + bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) + return bucket_pos + + +def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): + """ + Build relative position according to the query and key + + We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key + \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q - + P_k\\) + + Args: + query_size (int): the length of query + key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maximum allowed absolute position + + Return: + `torch.LongTensor`: A tensor with shape [1, query_size, key_size] + + """ + q_ids = np.arange(0, query_size) + k_ids = np.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (`DebertaV2Config`): + A model config class instance with the configuration to build a new model. The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if not self.share_att_key: + if "c2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + # self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`torch.ByteTensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + output_attentions (`bool`, optional): + Whether return the attention matrix. + + query_states (`torch.FloatTensor`, optional): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. + + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + scale = math.sqrt(query_layer.size(-1) * scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) + + # bsz x height x length x dimension + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(new_context_layer_shape) + if output_attentions: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + + att_span = self.pos_ebd_size + relative_pos = relative_pos.long().to(query_layer.device) + + # rel_index = torch.arange(0, att_span * 2).long().to(query_layer.device) + # rel_embeddings = rel_embeddings(rel_index).unsqueeze(0) + rel_embeddings = rel_embeddings.unsqueeze(0) + # rel_embeddings = rel_embeddings.unsqueeze(0) + # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) + else: + if "c2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + ) + score += c2p_att / scale + + # position->content + if "p2c" in self.pos_att_type: + scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ).to(query_layer.device) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), + ).transpose(-1, -2) + score += p2c_att / scale + + return score + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm +class DebertaV2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, "pad_token_id", 0) + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, "position_biased_input", True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 +class DebertaV2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DebertaV2Config + base_model_prefix = "deberta" + _keys_to_ignore_on_load_missing = ["position_ids"] + _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DebertaV2Encoder): + module.gradient_checkpointing = value + + +DEBERTA_START_DOCSTRING = r""" + The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled + Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build + on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior.``` + + + Parameters: + config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DEBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`DebertaV2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaV2Embeddings(config) + self.encoder = DebertaV2Encoder(config) + self.z_steps = 0 + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + output_attentions=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 +class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.deberta = DebertaV2Model(config) + self.cls = DebertaV2OnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaV2PredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaV2LMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = DebertaV2PredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaV2OnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = DebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 +class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + # regression task + loss_fn = nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather( + logits, 0, label_index.expand(label_index.size(0), logits.size(1)) + ) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + elif self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 +class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 +class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + DEBERTA_START_DOCSTRING, +) +class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, 1) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + self.init_weights() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.deberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py new file mode 100644 index 000000000..cce836913 --- /dev/null +++ b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -0,0 +1,182 @@ +import os +import random +import h5py +import logging +import json +import time +from concurrent.futures import ProcessPoolExecutor + +import numpy as np + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.sampler import RandomSampler +from torch.utils.data.distributed import DistributedSampler + +from bert_dataset_provider import BertDatasetProviderInterface +import colossalai.utils as utils + +# Workaround because python functions are not picklable +class WorkerInitObj(object): + def __init__(self, seed): + self.seed = seed + + def __call__(self, id): + np.random.seed(seed=self.seed + id) + random.seed(self.seed + id) + + +def create_pretraining_dataset(input_file, max_predictions_per_seq, + num_workers, train_batch_size, worker_init, + data_sampler): + train_data = pretraining_dataset( + input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) + train_dataloader = DataLoader(train_data, + sampler=data_sampler(train_data), + batch_size=train_batch_size, + num_workers=num_workers, + worker_init_fn=worker_init, + pin_memory=True + ) + return train_dataloader, len(train_data) + + +class pretraining_dataset(Dataset): + def __init__(self, input_file, max_predictions_per_seq): + self.input_file = input_file + self.max_predictions_per_seq = max_predictions_per_seq + f = h5py.File(input_file, "r") + keys = [ + 'input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions' + ] + self.inputs = [np.asarray(f[key][:]) for key in keys] + f.close() + + def __len__(self): + 'Denotes the total number of samples' + return len(self.inputs[0]) + + def __getitem__(self, index): + + [ + input_ids, input_mask, segment_ids, masked_lm_labels + ] = [ + torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else + torch.from_numpy(np.asarray(input[index].astype(np.int64))) + for indice, input in enumerate(self.inputs) + ] + + return [ + input_ids, input_mask, + segment_ids, masked_lm_labels + ] + + +class NvidiaBertDatasetProvider(BertDatasetProviderInterface): + def __init__(self, args, evaluate=False): + self.num_workers = args.num_workers + self.max_seq_length = args.max_seq_length + self.max_predictions_per_seq = args.max_predictions_per_seq + + self.gradient_accumulation_steps = args.gradient_accumulation_steps + if not evaluate: + self.train_micro_batch_size_per_gpu = args.train_micro_batch_size_per_gpu + else: + self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu + self.logger = args.logger + + self.global_rank = dist.get_rank() + self.world_size = dist.get_world_size() + + # Initialize dataset files + if not evaluate: + self.dataset_files = [ + os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) if + os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f + ] + else: + self.dataset_files = [ + os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) if + os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f + ] + + self.dataset_files.sort() + # random.shuffle(self.dataset_files) + self.num_files = len(self.dataset_files) + # self.data_sampler = RandomSampler + self.data_sampler = DistributedSampler + + self.worker_init = WorkerInitObj(args.seed + args.local_rank) + self.dataset_future = None + self.pool = ProcessPoolExecutor(1) + self.data_file = None + self.shuffle = True + + if self.global_rank == 0: + self.logger.info( + f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}" + ) + + def get_shard(self, index): + start = time.time() + if self.dataset_future is None: + self.data_file = self._get_shard_file(index) + self.train_dataloader, sample_count = create_pretraining_dataset( + input_file=self.data_file, + max_predictions_per_seq=self.max_predictions_per_seq, + num_workers=self.num_workers, + train_batch_size=self.train_micro_batch_size_per_gpu, + worker_init=self.worker_init, + data_sampler=self.data_sampler) + else: + self.train_dataloader, sample_count = self.dataset_future.result( + timeout=None) + + self.logger.info( + f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s." + ) + + return self.train_dataloader, sample_count + + def release_shard(self): + del self.train_dataloader + self.pool.shutdown() + + def prefetch_shard(self, index): + self.data_file = self._get_shard_file(index) + self.dataset_future = self.pool.submit( + create_pretraining_dataset, self.data_file, + self.max_predictions_per_seq, self.num_workers, + self.train_micro_batch_size_per_gpu, self.worker_init, + self.data_sampler) + + def get_batch(self, batch_iter): + return batch_iter + + def prefetch_batch(self): + pass + + def _get_shard_file(self, shard_index): + file_index = self._get_shard_file_index(shard_index, self.global_rank) + return self.dataset_files[file_index] + + def _get_shard_file_index(self, shard_index, global_rank): + # if dist.is_initialized() and self.world_size > self.num_files: + # remainder = self.world_size % self.num_files + # file_index = (shard_index * self.world_size) + global_rank + ( + # remainder * shard_index) + # else: + # file_index = shard_index * self.world_size + global_rank + + return shard_index % self.num_files + + def shuffle_dataset(self, epoch): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.num_files, generator=g).tolist() + new_dataset = [self.dataset_files[i] for i in indices] + self.dataset_files = new_dataset + \ No newline at end of file diff --git a/examples/language/roberta/pretraining/pretrain_utils.py b/examples/language/roberta/pretraining/pretrain_utils.py new file mode 100644 index 000000000..ba17b0f5e --- /dev/null +++ b/examples/language/roberta/pretraining/pretrain_utils.py @@ -0,0 +1,112 @@ +import transformers +import logging +from colossalai.nn.lr_scheduler import LinearWarmupLR +from transformers import get_linear_schedule_with_warmup +from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig +from transformers import GPT2Config, GPT2LMHeadModel +from transformers import AutoTokenizer, AutoModelForMaskedLM +from colossalai.nn.optimizer import FusedAdam +from torch.optim import AdamW +from colossalai.core import global_context as gpc +import torch +import os +import sys +sys.path.append(os.getcwd()) +from model.deberta_v2 import DebertaV2ForMaskedLM +from model.bert import BertForMaskedLM +import torch.nn as nn + +from collections import OrderedDict + +__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining'] + + +def get_new_state_dict(state_dict, start_index=13): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[start_index:] + new_state_dict[name] = v + return new_state_dict + + +class LMModel(nn.Module): + def __init__(self, model, config, args): + super().__init__() + + self.checkpoint = args.checkpoint_activations + self.config = config + self.model = model + if self.checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, token_type_ids=None, attention_mask=None): + # Only return lm_logits + return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + +def get_model(args, logger): + + if args.mlm == 'bert': + config = transformers.BertConfig.from_json_file(args.bert_config) + model = BertForMaskedLM(config) + elif args.mlm == 'deberta_v2': + config = transformers.DebertaV2Config.from_json_file(args.bert_config) + model = DebertaV2ForMaskedLM(config) + else: + raise Exception("Invalid mlm!") + + if len(args.load_pretrain_model) > 0: + assert os.path.exists(args.load_pretrain_model) + # load_checkpoint(args.load_pretrain_model, model, strict=False) + m_state_dict = torch.load(args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}")) + # new_state_dict = get_new_state_dict(m_state_dict) + model.load_state_dict(m_state_dict, strict=True) # must insure that every process have identical parameters !!!!!!! + logger.info("load model success") + + numel = sum([p.numel() for p in model.parameters()]) + if args.checkpoint_activations: + model.gradient_checkpointing_enable() + # model = LMModel(model, config, args) + + return config, model, numel + + +def get_optimizer(model, lr): + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] + + # configure the weight decay for bert models + optimizer_grouped_parameters = [{ + 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], + 'weight_decay': 0.1 + }, { + 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0 + }] + optimizer = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95]) + return optimizer + + +def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1): + # warmup_steps = int(total_steps * warmup_ratio) + lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch) + # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps) + return lr_scheduler + + +def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step): + model_path = path + '_pytorch_model.bin' + optimizer_lr_path = path + '.op_lrs' + checkpoint = {} + checkpoint['optimizer'] = optimizer.state_dict() + checkpoint['lr_scheduler'] = lr_scheduler.state_dict() + checkpoint['epoch'] = epoch + checkpoint['shard'] = shard + checkpoint['global_step'] = global_step + model_state = model.state_dict() #each process must run model.state_dict() + if gpc.get_global_rank() == 0: + torch.save(checkpoint, optimizer_lr_path) + torch.save(model_state, model_path) + + + diff --git a/examples/language/roberta/pretraining/run_pretrain.sh b/examples/language/roberta/pretraining/run_pretrain.sh new file mode 100644 index 000000000..144cd0ab9 --- /dev/null +++ b/examples/language/roberta/pretraining/run_pretrain.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env sh + +root_path=$PWD +PY_FILE_PATH="$root_path/run_pretraining.py" + +tensorboard_path="$root_path/tensorboard" +log_path="$root_path/exp_log" +ckpt_path="$root_path/ckpt" + +colossal_config="$root_path/../configs/colossalai_ddp.py" + +mkdir -p $tensorboard_path +mkdir -p $log_path +mkdir -p $ckpt_path + +export PYTHONPATH=$PWD + +env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ + --include GPU002,GPU003,GPU004,GPU007 \ + --nproc_per_node=8 \ + $PY_FILE_PATH \ + --master_addr GPU007 \ + --master_port 20024 \ + --lr 2.0e-4 \ + --train_micro_batch_size_per_gpu 190 \ + --eval_micro_batch_size_per_gpu 20 \ + --epoch 15 \ + --data_path_prefix /h5 \ + --eval_data_path_prefix /eval_h5 \ + --tokenizer_path /roberta \ + --bert_config /roberta/config.json \ + --tensorboard_path $tensorboard_path \ + --log_path $log_path \ + --ckpt_path $ckpt_path \ + --colossal_config $colossal_config \ + --log_interval 50 \ + --mlm bert \ + --wandb \ + --checkpoint_activations \ + \ No newline at end of file diff --git a/examples/language/roberta/pretraining/run_pretrain_resume.sh b/examples/language/roberta/pretraining/run_pretrain_resume.sh new file mode 100644 index 000000000..a0704cf7c --- /dev/null +++ b/examples/language/roberta/pretraining/run_pretrain_resume.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env sh + +root_path=$PWD +PY_FILE_PATH="$root_path/run_pretraining.py" + +tensorboard_path="$root_path/tensorboard" +log_path="$root_path/exp_log" +ckpt_path="$root_path/ckpt" + +colossal_config="$root_path/../configs/colossalai_ddp.py" + +mkdir -p $tensorboard_path +mkdir -p $log_path +mkdir -p $ckpt_path + +export PYTHONPATH=$PWD + +env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ + --include GPU002,GPU003,GPU004,GPU007 \ + --nproc_per_node=8 \ + $PY_FILE_PATH \ + --master_addr GPU007 \ + --master_port 20024 \ + --lr 2.0e-4 \ + --train_micro_batch_size_per_gpu 190 \ + --eval_micro_batch_size_per_gpu 20 \ + --epoch 15 \ + --data_path_prefix /h5 \ + --eval_data_path_prefix /eval_h5 \ + --tokenizer_path /roberta \ + --bert_config /roberta/config.json \ + --tensorboard_path $tensorboard_path \ + --log_path $log_path \ + --ckpt_path $ckpt_path \ + --colossal_config $colossal_config \ + --log_interval 50 \ + --mlm bert \ + --wandb \ + --checkpoint_activations \ + --resume_train \ + --load_pretrain_model /ckpt/1.pt \ + --load_optimizer_lr /ckpt/1.op_lrs \ + \ No newline at end of file diff --git a/examples/language/roberta/pretraining/run_pretraining.py b/examples/language/roberta/pretraining/run_pretraining.py new file mode 100644 index 000000000..9840a122c --- /dev/null +++ b/examples/language/roberta/pretraining/run_pretraining.py @@ -0,0 +1,226 @@ +import colossalai +import math +import torch +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +import colossalai.nn as col_nn +from arguments import parse_args +from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt +from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args +from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer +from utils.logger import Logger +from evaluation import evaluate +from loss import LossForPretraining + +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_optim import ShardedOptimizerV2 +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +from tqdm import tqdm +import os +import time +from functools import partial + +from transformers import AutoTokenizer + +from colossalai.gemini import ChunkManager, GeminiManager +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.utils import get_current_device +from colossalai.nn.parallel import ZeroDDP +from colossalai.zero import ZeroOptimizer +from colossalai.tensor import ProcessGroup +from colossalai.nn.optimizer import HybridAdam + + +def main(): + + args = parse_args() + launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + + os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + + logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) + + if args.vscode_debug: + colossalai.launch(config={}, + rank=args.rank, + world_size=args.world_size, + host=args.host, + port=args.port, + backend=args.backend) + args.local_rank = -1 + args.log_interval = 1 + else: + colossalai.launch_from_torch(args.colossal_config) #args.colossal_config + args.local_rank = int(os.environ["LOCAL_RANK"]) + logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + + f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}') + + log_args(logger, args) + args.tokenizer = tokenizer + args.logger = logger + set_global_variables(launch_time, args.tensorboard_path) + + use_zero = hasattr(gpc.config, 'zero') + world_size = torch.distributed.get_world_size() + + # build model, optimizer and criterion + if use_zero: + shard_strategy = TensorShardStrategy() + with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, + shard_param=True): + + config, model, numel = get_model(args, logger) + # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True) + else: + config, model, numel = get_model(args, logger) + logger.info("no_zero") + if torch.distributed.get_rank() == 0: + os.mkdir(os.path.join(args.ckpt_path, launch_time)) + + logger.info(f'Model numel: {numel}') + + get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) + steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) + total_steps = steps_per_epoch * args.epoch + + # build optimizer and lr_scheduler + + start_epoch = 0 + start_shard = 0 + global_step = 0 + if args.resume_train: + assert os.path.exists(args.load_optimizer_lr) + o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu') + o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 + optimizer = get_optimizer(model, lr=args.lr) + optimizer.load_state_dict(o_l_state_dict['optimizer']) + lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch'] + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") + # if you want delete the above three code, have to move the model to gpu, because in optimizer.step() + lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) + + start_epoch = o_l_state_dict['epoch'] + start_shard = o_l_state_dict['shard'] + 1 + # global_step = o_l_state_dict['global_step'] + 1 + logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') + else: + optimizer = get_optimizer(model, lr=args.lr) + lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) + + # optimizer = gpc.config.optimizer.pop('type')( + # model.parameters(), **gpc.config.optimizer) + # optimizer = ShardedOptimizerV2(model, optimizer, initial_scale=2**5) + criterion = LossForPretraining(config.vocab_size) + + # build dataloader + pretrain_dataset_provider = NvidiaBertDatasetProvider(args) + + # initialize with colossalai + engine, _, _, lr_scheduelr = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler) + + logger.info(get_mem_info(prefix='After init model, ')) + + + best_loss = None + eval_loss = 0 + train_loss = 0 + timers = get_timers() + timers('interval_time').start() + timers('epoch_time').start() + timers('shard_time').start() + + for epoch in range(start_epoch, args.epoch): + + for shard in range(start_shard, len(os.listdir(args.data_path_prefix))): + + dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) + # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload + if torch.distributed.get_rank() == 0: + iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1) + else: + iterator_data = enumerate(dataset_iterator) + + engine.train() + + for step, batch_data in iterator_data: + + # batch_data = pretrain_dataset_provider.get_batch(batch_index) + input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") + attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}") + token_type_ids = batch_data[2].cuda(f"cuda:{torch.cuda.current_device()}") + mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}") + # nsp_label = batch_data[5].cuda() + + output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + loss = engine.criterion(output.logits, mlm_label) + pretrain_dataset_provider.prefetch_batch() + + engine.backward(loss) + train_loss += loss.float().item() + # if (step + 1) % args.accumulation_step == 0: + engine.step() + lr_scheduelr.step() + engine.zero_grad() + + global_step += 1 + + if global_step % args.log_interval == 0 and global_step != 0 \ + and torch.distributed.get_rank() == 0: + elapsed_time = timers('interval_time').elapsed(reset=False) + elapsed_time_per_iteration = elapsed_time / global_step + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size) + + cur_loss = train_loss / args.log_interval + current_lr = lr_scheduelr.get_last_lr()[0] + log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ + f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}' + logger.info(log_str, print_=False) + + if args.wandb: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_train({ + 'lr': current_lr, + 'loss': cur_loss, + 'ppl': math.exp(cur_loss), + 'mins_batch': elapsed_time_per_iteration + }, global_step) + + train_loss = 0 + + logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins') + logger.info('*' * 100) + + eval_loss += evaluate(engine, args, logger, global_step) + save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) + + + eval_loss /= len(os.listdir(args.data_path_prefix)) + logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \ + f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') + logger.info('-' * 100) + if args.wandb and torch.distributed.get_rank() == 0: + tensorboard_log = get_tensorboard_writer() + tensorboard_log.log_eval({ + 'all_eval_shard_loss': eval_loss, + }, epoch) + start_shard = 0 + eval_loss = 0 + + pretrain_dataset_provider.release_shard() + + logger.info('Congratulation, training has finished!!!') + + +if __name__ == '__main__': + main() diff --git a/examples/language/roberta/pretraining/utils/WandbLog.py b/examples/language/roberta/pretraining/utils/WandbLog.py new file mode 100644 index 000000000..9dd28a981 --- /dev/null +++ b/examples/language/roberta/pretraining/utils/WandbLog.py @@ -0,0 +1,46 @@ +import time +import wandb +import os +from torch.utils.tensorboard import SummaryWriter + +class WandbLog: + + @classmethod + def init_wandb(cls, project, notes=None, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): + wandb.init(project=project, notes=notes, name=name, config=config) + + @classmethod + def log(cls, result, model=None, gradient=None): + wandb.log(result) + + if model: + wandb.watch(model) + + if gradient: + wandb.watch(gradient) + + +class TensorboardLog: + + def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): + if not os.path.exists(location): + os.mkdir(location) + self.writer = SummaryWriter(location, comment=name) + + def log_train(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}/train', v, step) + + def log_eval(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}/eval', v, step) + + def log_zeroshot(self, result, step): + for k, v in result.items(): + self.writer.add_scalar(f'{k}_acc/eval', v, step) + + + + + + diff --git a/examples/language/roberta/pretraining/utils/exp_util.py b/examples/language/roberta/pretraining/utils/exp_util.py new file mode 100644 index 000000000..a02b0872a --- /dev/null +++ b/examples/language/roberta/pretraining/utils/exp_util.py @@ -0,0 +1,99 @@ +import functools +import os, shutil +import torch +import psutil +from colossalai.core import global_context as gpc + +def logging(s, log_path, print_=True, log_=True): + if print_: + print(s) + if log_: + with open(log_path, 'a+') as f_log: + f_log.write(s + '\n') + +def get_logger(log_path, **kwargs): + return functools.partial(logging, log_path=log_path, **kwargs) + +def create_exp_dir(dir_path, scripts_to_save=None, debug=False): + if debug: + print('Debug Mode : no experiment dir created') + return functools.partial(logging, log_path=None, log_=False) + + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + print('Experiment dir : {}'.format(dir_path)) + if scripts_to_save is not None: + script_path = os.path.join(dir_path, 'scripts') + if not os.path.exists(script_path): + os.makedirs(script_path) + for script in scripts_to_save: + dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) + shutil.copyfile(script, dst_file) + + return get_logger(log_path=os.path.join(dir_path, 'log.txt')) + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def get_parameters_in_billions(model, world_size=1): + gpus_per_model = world_size + + approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()]) + for model_module in model]) + + return approx_parameters_in_billions * gpus_per_model / (1e9) + +def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1): + gpus_per_model = 1 + batch_size = args.train_micro_batch_size_per_gpu + samples_per_model = batch_size * args.max_seq_length + model_replica_count = world_size / gpus_per_model + approx_parameters_in_billions = numel + elapsed_time_per_iter = iteration_time / total_iterations + samples_per_second = batch_size / elapsed_time_per_iter + + #flops calculator + hidden_size = config.hidden_size + num_layers = config.num_hidden_layers + vocab_size = config.vocab_size + + # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of + # https://arxiv.org/pdf/2104.04473.pdf). + # The factor of 4 is when used with activation check-pointing, + # otherwise it will be 3. + checkpoint_activations_factor = 4 if args.checkpoint_activations else 3 + flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size))) + tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12)) + return samples_per_second, tflops, approx_parameters_in_billions + +def synchronize(): + if not torch.distributed.is_available(): + return + if not torch.distributed.is_intialized(): + return + world_size = torch.distributed.get_world_size() + if world_size == 1: + return + torch.distributed.barrier() + +def log_args(logger, args): + logger.info('--------args----------') + message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()]) + message += '\n' + message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()]) + logger.info(message) + logger.info('--------args----------\n') \ No newline at end of file diff --git a/examples/language/roberta/pretraining/utils/global_vars.py b/examples/language/roberta/pretraining/utils/global_vars.py new file mode 100644 index 000000000..363cbf91c --- /dev/null +++ b/examples/language/roberta/pretraining/utils/global_vars.py @@ -0,0 +1,126 @@ +import time +import torch +from .WandbLog import TensorboardLog + +_GLOBAL_TIMERS = None +_GLOBAL_TENSORBOARD_WRITER = None + + +def set_global_variables(launch_time, tensorboard_path): + _set_timers() + _set_tensorboard_writer(launch_time, tensorboard_path) + +def _set_timers(): + """Initialize timers.""" + global _GLOBAL_TIMERS + _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') + _GLOBAL_TIMERS = Timers() + +def _set_tensorboard_writer(launch_time, tensorboard_path): + """Set tensorboard writer.""" + global _GLOBAL_TENSORBOARD_WRITER + _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, + 'tensorboard writer') + if torch.distributed.get_rank() == 0: + _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time) + +def get_timers(): + """Return timers.""" + _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') + return _GLOBAL_TIMERS + +def get_tensorboard_writer(): + """Return tensorboard writer. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_TENSORBOARD_WRITER + +def _ensure_var_is_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is not None, '{} is not initialized.'.format(name) + + +def _ensure_var_is_not_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is None, '{} is already initialized.'.format(name) + + +class _Timer: + """Timer.""" + + def __init__(self, name): + self.name_ = name + self.elapsed_ = 0.0 + self.started_ = False + self.start_time = time.time() + + def start(self): + """Start the timer.""" + # assert not self.started_, 'timer has already been started' + torch.cuda.synchronize() + self.start_time = time.time() + self.started_ = True + + def stop(self): + """Stop the timer.""" + assert self.started_, 'timer is not started' + torch.cuda.synchronize() + self.elapsed_ += (time.time() - self.start_time) + self.started_ = False + + def reset(self): + """Reset timer.""" + self.elapsed_ = 0.0 + self.started_ = False + + def elapsed(self, reset=True): + """Calculate the elapsed time.""" + started_ = self.started_ + # If the timing in progress, end it first. + if self.started_: + self.stop() + # Get the elapsed time. + elapsed_ = self.elapsed_ + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if started_: + self.start() + return elapsed_ + + +class Timers: + """Group of timers.""" + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = _Timer(name) + return self.timers[name] + + def write(self, names, writer, iteration, normalizer=1.0, reset=False): + """Write timers to a tensorboard writer""" + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # polutes the runs list, so we just add each as a scalar + assert normalizer > 0.0 + for name in names: + value = self.timers[name].elapsed(reset=reset) / normalizer + writer.add_scalar(name + '-time', value, iteration) + + def log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + assert normalizer > 0.0 + string = 'time (ms)' + for name in names: + elapsed_time = self.timers[name].elapsed( + reset=reset) * 1000.0 / normalizer + string += ' | {}: {:.2f}'.format(name, elapsed_time) + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == ( + torch.distributed.get_world_size() - 1): + print(string, flush=True) + else: + print(string, flush=True) diff --git a/examples/language/roberta/pretraining/utils/logger.py b/examples/language/roberta/pretraining/utils/logger.py new file mode 100644 index 000000000..481c4c6ce --- /dev/null +++ b/examples/language/roberta/pretraining/utils/logger.py @@ -0,0 +1,31 @@ +import os +import logging +import torch.distributed as dist + +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO) +logger = logging.getLogger(__name__) + + +class Logger(): + def __init__(self, log_path, cuda=False, debug=False): + self.logger = logging.getLogger(__name__) + self.cuda = cuda + self.log_path = log_path + self.debug = debug + + + def info(self, message, log_=True, print_=True, *args, **kwargs): + if (self.cuda and dist.get_rank() == 0) or not self.cuda: + if print_: + self.logger.info(message, *args, **kwargs) + + if log_: + with open(self.log_path, 'a+') as f_log: + f_log.write(message + '\n') + + + def error(self, message, *args, **kwargs): + self.logger.error(message, *args, **kwargs) diff --git a/examples/language/roberta/requirements.txt b/examples/language/roberta/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/language/roberta/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/tutorial/.gitignore b/examples/tutorial/.gitignore new file mode 100644 index 000000000..f873b6a4a --- /dev/null +++ b/examples/tutorial/.gitignore @@ -0,0 +1 @@ +./data/ diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md new file mode 100644 index 000000000..bef7c8905 --- /dev/null +++ b/examples/tutorial/README.md @@ -0,0 +1,193 @@ +# Colossal-AI Tutorial Hands-on + +## Introduction + +Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), etc. + + +[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates +many advanced technologies such as multi-dimensional tensor parallelism, sequence parallelism, heterogeneous memory management, +large-scale optimization, adaptive task scheduling, etc. By using Colossal-AI, we could help users to efficiently and +quickly deploy large AI model training and inference, reducing large AI model training budgets and scaling down the labor cost of learning and deployment. + +### 🚀 Quick Links + +[**Colossal-AI**](https://github.com/hpcaitech/ColossalAI) | +[**Paper**](https://arxiv.org/abs/2110.14883) | +[**Documentation**](https://www.colossalai.org/) | +[**Forum**](https://github.com/hpcaitech/ColossalAI/discussions) | +[**Slack**](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + +## Table of Content + + - Multi-dimensional Parallelism + - Know the components and sketch of Colossal-AI + - Step-by-step from PyTorch to Colossal-AI + - Try data/pipeline parallelism and 1D/2D/2.5D/3D tensor parallelism using a unified model + - Sequence Parallelism + - Try sequence parallelism with BERT + - Combination of data/pipeline/sequence parallelism + - Faster training and longer sequence length + - Large Batch Training Optimization + - Comparison of small/large batch size with SGD/LARS optimizer + - Acceleration from a larger batch size + - Auto-Parallelism + - Parallelism with normal non-distributed training code + - Model tracing + solution solving + runtime communication inserting all in one auto-parallelism system + - Try single program, multiple data (SPMD) parallel with auto-parallelism SPMD solver on ResNet50 + - Fine-tuning and Serving for OPT + - Try pre-trained OPT model weights with Colossal-AI + - Fine-tuning OPT with limited hardware using ZeRO, Gemini and parallelism + - Deploy the fine-tuned model to inference service + - Acceleration of Stable Diffusion + - Stable Diffusion with Lightning + - Try Lightning Colossal-AI strategy to optimize memory and accelerate speed + + +## Discussion + +Discussion about the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project is always welcomed! We would love to exchange ideas with the community to better help this project grow. +If you think there is a need to discuss anything, you may jump to our [Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w). + +If you encounter any problem while running these tutorials, you may want to raise an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) in this repository. + +## 🛠️ Setup environment +You should use `conda` to create a virtual environment, we recommend **python 3.8**, e.g. `conda create -n colossal python=3.8`. This installation commands are for CUDA 11.3, if you have a different version of CUDA, please download PyTorch and Colossal-AI accordingly. + +``` +# install torch +# visit https://pytorch.org/get-started/locally/ to download other versions +pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 + +# install latest ColossalAI +# visit https://colossalai.org/download to download corresponding version of Colossal-AI +pip install colossalai==0.1.11rc3+torch1.12cu11.3 -f https://release.colossalai.org +``` + +You can run `colossalai check -i` to verify if you have correctly set up your environment 🕹️. +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/colossalai%20check%20-i.png) + +If you encounter messages like `please install with cuda_ext`, do let me know as it could be a problem of the distribution wheel. 😥 + +Then clone the Colossal-AI repository from GitHub. +```bash +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI/examples/tutorial +``` + +## 🔥 Multi-dimensional Hybrid Parallel with Vision Transformer +1. Go to **hybrid_parallel** folder in the **tutorial** directory. +2. Install our model zoo. +```bash +pip install titans +``` +3. Run with synthetic data which is of similar shape to CIFAR10 with the `-s` flag. +```bash +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` + +4. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs. + +## ☀️ Sequence Parallel with BERT +1. Go to the **sequence_parallel** folder in the **tutorial** directory. +2. Run with the following command +```bash +export PYTHONPATH=$PWD +colossalai run --nproc_per_node 4 train.py -s +``` +3. The default config is sequence parallel size = 2, pipeline size = 1, let’s change pipeline size to be 2 and try it again. + +## 📕 Large batch optimization with LARS and LAMB +1. Go to the **large_batch_optimizer** folder in the **tutorial** directory. +2. Run with synthetic data +```bash +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` + +## 😀 Auto-Parallel Tutorial +1. Go to the **auto_parallel** folder in the **tutorial** directory. +2. Install `pulp` and `coin-or-cbc` for the solver. +```bash +pip install pulp +conda install -c conda-forge coin-or-cbc +``` +2. Run the auto parallel resnet example with 4 GPUs with synthetic dataset. +```bash +colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s +``` + +You should expect to the log like this. This log shows the edge cost on the computation graph as well as the sharding strategy for an operation. For example, `layer1_0_conv1 S01R = S01R X RR` means that the first dimension (batch) of the input and output is sharded while the weight is not sharded (S means sharded, R means replicated), simply equivalent to data parallel training. +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-parallel%20demo.png) + +## 🎆 Auto-Checkpoint Tutorial +1. Stay in the `auto_parallel` folder. +2. Install the dependencies. +```bash +pip install matplotlib transformers +``` +3. Run a simple resnet50 benchmark to automatically checkpoint the model. +```bash +python auto_ckpt_solver_test.py --model resnet50 +``` + +You should expect the log to be like this +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-ckpt%20demo.png) + +This shows that given different memory budgets, the model is automatically injected with activation checkpoint and its time taken per iteration. You can run this benchmark for GPT as well but it can much longer since the model is larger. +```bash +python auto_ckpt_solver_test.py --model gpt2 +``` + +4. Run a simple benchmark to find the optimal batch size for checkpointed model. +```bash +python auto_ckpt_batchsize_test.py +``` + +You can expect the log to be like +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-ckpt%20batchsize.png) + +## 🚀 Run OPT finetuning and inference +1. Install the dependency +```bash +pip install datasets accelerate +``` +2. Run finetuning with synthetic datasets with one GPU +```bash +bash ./run_clm_synthetic.sh +``` +3. Run finetuning with 4 GPUs +```bash +bash ./run_clm_synthetic.sh 16 0 125m 4 +``` +4. Run inference with OPT 125M +```bash +docker hpcaitech/tutorial:opt-inference +docker run -it --rm --gpus all --ipc host -p 7070:7070 hpcaitech/tutorial:opt-inference +``` +5. Start the http server inside the docker container with tensor parallel size 2 +```bash +python opt_fastapi.py opt-125m --tp 2 --checkpoint /data/opt-125m +``` + +## 🖼️ Accelerate Stable Diffusion with Colossal-AI +1. Create a new environment for diffusion +```bash +conda env create -f environment.yaml +conda activate ldm +``` +2. Install Colossal-AI from our official page +```bash +pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org +``` +3. Install PyTorch Lightning compatible commit +```bash +git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa +pip install -r requirements.txt && pip install . +cd .. +``` + +4. Comment out the `from_pretrained` field in the `train_colossalai_cifar10.yaml`. +5. Run training with CIFAR10. +```bash +python main.py -logdir /tmp -t true -postfix test -b configs/train_colossalai_cifar10.yaml +``` diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/auto_parallel/README.md new file mode 100644 index 000000000..e99a018c2 --- /dev/null +++ b/examples/tutorial/auto_parallel/README.md @@ -0,0 +1,106 @@ +# Auto-Parallelism with ResNet + +## 🚀Quick Start +### Auto-Parallel Tutorial +1. Install `pulp` and `coin-or-cbc` for the solver. +```bash +pip install pulp +conda install -c conda-forge coin-or-cbc +``` +2. Run the auto parallel resnet example with 4 GPUs with synthetic dataset. +```bash +colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s +``` + +You should expect to the log like this. This log shows the edge cost on the computation graph as well as the sharding strategy for an operation. For example, `layer1_0_conv1 S01R = S01R X RR` means that the first dimension (batch) of the input and output is sharded while the weight is not sharded (S means sharded, R means replicated), simply equivalent to data parallel training. +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-parallel%20demo.png) + + +### Auto-Checkpoint Tutorial +1. Stay in the `auto_parallel` folder. +2. Install the dependencies. +```bash +pip install matplotlib transformers +``` +3. Run a simple resnet50 benchmark to automatically checkpoint the model. +```bash +python auto_ckpt_solver_test.py --model resnet50 +``` + +You should expect the log to be like this +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-ckpt%20demo.png) + +This shows that given different memory budgets, the model is automatically injected with activation checkpoint and its time taken per iteration. You can run this benchmark for GPT as well but it can much longer since the model is larger. +```bash +python auto_ckpt_solver_test.py --model gpt2 +``` + +4. Run a simple benchmark to find the optimal batch size for checkpointed model. +```bash +python auto_ckpt_batchsize_test.py +``` + +You can expect the log to be like +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/tutorial/auto-ckpt%20batchsize.png) + + +## Prepare Dataset + +We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`. +The dataset will be downloaded to `colossalai/examples/tutorials/data` by default. +If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. + +```bash +export DATA=/path/to/data +``` + +## extra requirements to use autoparallel + +```bash +pip install pulp +conda install coin-or-cbc +``` + +## Run on 2*2 device mesh + +```bash +colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py +``` + +## Auto Checkpoint Benchmarking + +We prepare two bechmarks for you to test the performance of auto checkpoint + +The first test `auto_ckpt_solver_test.py` will show you the ability of solver to search checkpoint strategy that could fit in the given budget (test on GPT2 Medium and ResNet 50). It will output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory. + +The second test `auto_ckpt_batchsize_test.py` will show you the advantage of fitting larger batchsize training into limited GPU memory with the help of our activation checkpoint solver (test on ResNet152). It will output the benchmark summary. + +The usage of the above two test +```bash +# run auto_ckpt_solver_test.py on gpt2 medium +python auto_ckpt_solver_test.py --model gpt2 + +# run auto_ckpt_solver_test.py on resnet50 +python auto_ckpt_solver_test.py --model resnet50 + +# tun auto_ckpt_batchsize_test.py +python auto_ckpt_batchsize_test.py +``` + +There are some results for your reference + +## Auto Checkpoint Solver Test + +### ResNet 50 +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/resnet50_benchmark.png) + +### GPT2 Medium +![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/gpt2_benchmark.png) + +## Auto Checkpoint Batch Size Test +```bash +===============test summary================ +batch_size: 512, peak memory: 73314.392 MB, through put: 254.286 images/s +batch_size: 1024, peak memory: 73316.216 MB, through put: 397.608 images/s +batch_size: 2048, peak memory: 72927.837 MB, through put: 277.429 images/s +``` diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py new file mode 100644 index 000000000..5decfc695 --- /dev/null +++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py @@ -0,0 +1,59 @@ +import time +from argparse import ArgumentParser +from copy import deepcopy +from functools import partial + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.multiprocessing as mp +import torchvision.models as tm +from bench_utils import bench, data_gen_resnet + +import colossalai +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace, symbolic_trace +from colossalai.utils import free_port + + +def _benchmark(rank, world_size, port): + """Auto activation checkpoint batchsize benchmark + This benchmark test the through put of Resnet152 with our activation solver given the memory budget of 95% of + maximum GPU memory, and with the batch size of [512, 1024, 2048], you could see that using auto activation + checkpoint with optimality guarantee, we might be able to find better batch size for the model, as larger batch + size means that we are able to use larger portion of GPU FLOPS, while recomputation scheduling with our solver + only result in minor performance drop. So at last we might be able to find better training batch size for our + model (combine with large batch training optimizer such as LAMB). + """ + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = tm.resnet152() + gm = symbolic_trace(model) + raw_graph = deepcopy(gm.graph) + peak_mems, through_puts, batch_sizes = [], [], [512, 1024, 2048] + for batch_size in batch_sizes: + batch_size = int(batch_size) + gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta')) + solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95) + gm.graph = solver.solve() + peak_mem, step_time = bench(gm, + torch.nn.CrossEntropyLoss(), + partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)), + num_steps=5) + peak_mems.append(peak_mem) + through_puts.append(batch_size / step_time * 1.0e3) + gm.graph = deepcopy(raw_graph) + + # print results + print("===============benchmark summary================") + for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts): + print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s') + + +def auto_activation_checkpoint_batchsize_benchmark(): + world_size = 1 + run_func_module = partial(_benchmark, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == "__main__": + auto_activation_checkpoint_batchsize_benchmark() diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py new file mode 100644 index 000000000..ab0f2ef66 --- /dev/null +++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py @@ -0,0 +1,89 @@ +import time +from argparse import ArgumentParser +from functools import partial + +import matplotlib.pyplot as plt +import torch +import torch.multiprocessing as mp +import torchvision.models as tm +from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium + +import colossalai +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace, symbolic_trace +from colossalai.utils import free_port + + +def _benchmark(rank, world_size, port, args): + """ + Auto activation checkpoint solver benchmark, we provide benchmark on two models: gpt2_medium and resnet50. + The benchmark will sample in a range of memory budget for each model and output the benchmark summary and + data visualization of peak memory vs. budget memory and relative step time vs. peak memory. + """ + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if args.model == 'resnet50': + model = tm.resnet50() + data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224)) + gm = symbolic_trace(model) + gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta')) + loss = torch.nn.CrossEntropyLoss() + else: + model = gpt2_medium() + data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257) + data, mask = data_gen(device='meta')[0] + gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + gm = metainfo_trace(gm, data, mask) + loss = GPTLMLoss() + + free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2 + start_factor = 4 if args.model == 'resnet50' else 10 + + # trace and benchmark + budgets, peak_hist, step_hist = bench_rotor(gm, + loss, + data_gen, + num_steps=5, + sample_points=15, + free_memory=free_memory, + start_factor=start_factor) + + # print summary + print("==============benchmark summary==============") + for budget, peak, step in zip(budgets, peak_hist, step_hist): + print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS') + + # plot valid results + fig, axs = plt.subplots(1, 2, figsize=(16, 8)) + valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf"))) + + # plot peak memory vs. budget memory + axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:]) + axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--') + axs[0].set_xlabel("Budget Memory (MB)") + axs[0].set_ylabel("Peak Memory (MB)") + axs[0].set_title("Peak Memory vs. Budget Memory") + + # plot relative step time vs. budget memory + axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]]) + axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--') + axs[1].set_xlabel("Peak Memory (MB)") + axs[1].set_ylabel("Relative Step Time") + axs[1].set_title("Step Time vs. Peak Memory") + axs[1].set_ylim(0.8, 1.5) + + # save plot + fig.savefig(f"{args.model}_benchmark.png") + + +def auto_activation_checkpoint_benchmark(args): + world_size = 1 + run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == "__main__": + parser = ArgumentParser("Auto Activation Checkpoint Solver Benchmark") + parser.add_argument("--model", type=str, default='gpt2', choices=['gpt2', 'resnet50']) + args = parser.parse_args() + + auto_activation_checkpoint_benchmark(args) diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py new file mode 100644 index 000000000..e4aff13e4 --- /dev/null +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -0,0 +1,200 @@ +import argparse +import os +from pathlib import Path + +import torch +from titans.utils import barrier_context +from torch.fx import GraphModule +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet50 +from tqdm import tqdm + +import colossalai +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions +from colossalai.auto_parallel.tensor_shard.solver.solver import Solver +from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingLR +from colossalai.utils import get_dataloader + +DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10") + return parser.parse_args() + + +def synthesize_data(): + img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32) + label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,)) + return img, label + + +def main(): + args = parse_args() + colossalai.launch_from_torch(config='./config.py') + + logger = get_dist_logger() + + if not args.synthetic: + with barrier_context(): + # build dataloaders + train_dataset = CIFAR10(root=DATA_ROOT, + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(size=32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010]), + ])) + + test_dataset = CIFAR10(root=DATA_ROOT, + train=False, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ])) + + train_dataloader = get_dataloader( + dataset=train_dataset, + add_sampler=True, + shuffle=True, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + ) + + test_dataloader = get_dataloader( + dataset=test_dataset, + add_sampler=True, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + ) + else: + train_dataloader, test_dataloader = None, None + + # initialize device mesh + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # trace the model with meta data + tracer = ColoTracer() + model = resnet50(num_classes=10).cuda() + input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # prepare info for solver + solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + + # solve the solution + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + if gpc.get_global_rank() == 0: + for index, node in enumerate(graph.nodes): + print(node.name, node.strategies_vector[solution[index]].name) + + # process the graph for distributed training ability + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + gm = runtime_apply_pass(gm) + gm.recompile() + + # build criterion + criterion = torch.nn.CrossEntropyLoss() + + # optimizer + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) + + # lr_scheduler + lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) + + for epoch in range(gpc.config.NUM_EPOCHS): + gm.train() + + if args.synthetic: + # if we use synthetic data + # we assume it only has 30 steps per epoch + num_steps = range(30) + + else: + # we use the actual number of steps for training + num_steps = range(len(train_dataloader)) + data_iter = iter(train_dataloader) + progress = tqdm(num_steps) + + for _ in progress: + if args.synthetic: + # generate fake data + img, label = synthesize_data() + else: + # get the real data + img, label = next(data_iter) + + img = img.cuda() + label = label.cuda() + optimizer.zero_grad() + output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + train_loss = criterion(output, label) + train_loss.backward(train_loss) + optimizer.step() + lr_scheduler.step() + + # run evaluation + gm.eval() + correct = 0 + total = 0 + + if args.synthetic: + # if we use synthetic data + # we assume it only has 10 steps for evaluation + num_steps = range(30) + + else: + # we use the actual number of steps for training + num_steps = range(len(test_dataloader)) + data_iter = iter(test_dataloader) + progress = tqdm(num_steps) + + for _ in progress: + if args.synthetic: + # generate fake data + img, label = synthesize_data() + else: + # get the real data + img, label = next(data_iter) + + img = img.cuda() + label = label.cuda() + + with torch.no_grad(): + output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + test_loss = criterion(output, label) + pred = torch.argmax(output, dim=-1) + correct += torch.sum(pred == label) + total += img.size(0) + + logger.info( + f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", + ranks=[0]) + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/auto_parallel/bench_utils.py b/examples/tutorial/auto_parallel/bench_utils.py new file mode 100644 index 000000000..69859f885 --- /dev/null +++ b/examples/tutorial/auto_parallel/bench_utils.py @@ -0,0 +1,170 @@ +import time +from copy import deepcopy +from functools import partial +from typing import Callable, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torchvision.models as tm +from transformers import GPT2Config, GPT2LMHeadModel + +from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor +from colossalai.fx import metainfo_trace + + +def bench(gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5) -> Tuple[int, int]: + """Benchmarking a given graph module + Args: + gm (torch.fx.GraphModule): The graph module to benchmark. + criterion (torch.nn.Module): Loss function. + data_gen (Callable): Data generator. + num_steps (int, optional): Number of test steps. Defaults to 5. + Returns: + Tuple[int, int]: peak memory in MB and step time in MS. + """ + gm.train() + gm.cuda() + step_time = float('inf') + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + cached = torch.cuda.max_memory_allocated(device="cuda") + try: + for _ in range(num_steps): + args, label = data_gen() + output, loss = None, None + + torch.cuda.synchronize(device="cuda") + start = time.time() + output = gm(*args) + loss = criterion(output, label) + loss.backward() + torch.cuda.synchronize(device="cuda") + step_time = min(step_time, time.time() - start) + + for child in gm.children(): + for param in child.parameters(): + param.grad = None + del args, label, output, loss + except: + del args, label, output, loss + gm.to("cpu") + torch.cuda.empty_cache() + peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2 + return peak_mem, step_time * 1.0e3 + + +def bench_rotor(gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5, + sample_points: int = 20, + free_memory: int = torch.cuda.mem_get_info()[0], + start_factor: int = 4) -> Tuple[np.array, list, list]: + """Auto Checkpoint Rotor Algorithm benchmarking + Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data. + Args: + gm (torch.fx.GraphModule): The graph module to benchmark. + criterion (torch.nn.Module): Loss function. + data_gen (Callable): Data generator. + num_steps (int, optional): Number of test steps. Defaults to 5. + sample_points (int, optional): Number of sample points. Defaults to 20. + free_memory (int, optional): Max memory budget in Byte. Defaults to torch.cuda.mem_get_info()[0]. + start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget + will be free_memory / start_factor. Defaults to 4. + Returns: + Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS). + """ + peak_hist, step_hist = [], [] + raw_graph = deepcopy(gm.graph) + for budget in np.linspace(free_memory // start_factor, free_memory, sample_points): + gm = metainfo_trace(gm, *data_gen()[0]) + solver = CheckpointSolverRotor(gm.graph, free_memory=budget) + try: + gm.graph = solver.solve(verbose=False) + peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps) + except: + peak_memory, step_time = budget / 1024**2, float('inf') + peak_hist.append(peak_memory) + step_hist.append(step_time) + gm.graph = deepcopy(raw_graph) + return np.linspace(free_memory // start_factor, free_memory, sample_points) / 1024**2, peak_hist, step_hist + + +class GPTLMModel(nn.Module): + """ + GPT Model + """ + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + """ + GPT Loss + """ + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=False): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_6b(checkpoint=False): + return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) + + +def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'): + """ + Generate random data for gpt2 benchmarking + """ + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + return (input_ids, attention_mask), attention_mask + + +def data_gen_resnet(batch_size, shape, device='cuda:0'): + """ + Generate random data for resnet benchmarking + """ + data = torch.empty(batch_size, *shape, device=device) + label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) + return (data,), label diff --git a/examples/tutorial/auto_parallel/config.py b/examples/tutorial/auto_parallel/config.py new file mode 100644 index 000000000..fa14eda74 --- /dev/null +++ b/examples/tutorial/auto_parallel/config.py @@ -0,0 +1,2 @@ +BATCH_SIZE = 128 +NUM_EPOCHS = 10 diff --git a/examples/tutorial/auto_parallel/requirements.txt b/examples/tutorial/auto_parallel/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/tutorial/auto_parallel/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/tutorial/download_cifar10.py b/examples/tutorial/download_cifar10.py new file mode 100644 index 000000000..5c6b6988a --- /dev/null +++ b/examples/tutorial/download_cifar10.py @@ -0,0 +1,13 @@ +import os + +from torchvision.datasets import CIFAR10 + + +def main(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + data_root = os.path.join(dir_path, 'data') + dataset = CIFAR10(root=data_root, download=True) + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/hybrid_parallel/README.md b/examples/tutorial/hybrid_parallel/README.md new file mode 100644 index 000000000..6f975e863 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/README.md @@ -0,0 +1,45 @@ +# Multi-dimensional Parallelism with Colossal-AI + + +## 🚀Quick Start +1. Install our model zoo. +```bash +pip install titans +``` +2. Run with synthetic data which is of similar shape to CIFAR10 with the `-s` flag. +```bash +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` + +3. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs. + + +## Install Titans Model Zoo + +```bash +pip install titans +``` + + +## Prepare Dataset + +We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`. +The dataset will be downloaded to `colossalai/examples/tutorials/data` by default. +If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. + +```bash +export DATA=/path/to/data +``` + + +## Run on 2*2 device mesh + +Current configuration setting on `config.py` is TP=2, PP=2. + +```bash +# train with cifar10 +colossalai run --nproc_per_node 4 train.py --config config.py + +# train with synthetic data +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py new file mode 100644 index 000000000..ac273c305 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/config.py @@ -0,0 +1,36 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 2 +WARMUP_EPOCHS = 1 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 512 +DEPTH = 4 +NUM_HEADS = 4 +MLP_RATIO = 2 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 + +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] diff --git a/examples/tutorial/hybrid_parallel/requirements.txt b/examples/tutorial/hybrid_parallel/requirements.txt new file mode 100644 index 000000000..dbf6aaf3e --- /dev/null +++ b/examples/tutorial/hybrid_parallel/requirements.txt @@ -0,0 +1,3 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 +titans \ No newline at end of file diff --git a/examples/tutorial/hybrid_parallel/test_ci.sh b/examples/tutorial/hybrid_parallel/test_ci.sh new file mode 100644 index 000000000..8860b72a2 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/test_ci.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set -euxo pipefail + +pip install -r requirements.txt +torchrun --standalone --nproc_per_node 4 train.py --config config.py -s diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py new file mode 100644 index 000000000..2a8576db7 --- /dev/null +++ b/examples/tutorial/hybrid_parallel/train.py @@ -0,0 +1,145 @@ +import os + +import torch +from titans.dataloader.cifar10 import build_cifar +from titans.model.vit.vit import _create_vit_model +from tqdm import tqdm + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.utils import get_dataloader, is_using_pp + + +class DummyDataloader(): + + def __init__(self, length, batch_size): + self.length = length + self.batch_size = batch_size + + def generate(self): + data = torch.rand(self.batch_size, 3, 224, 224) + label = torch.randint(low=0, high=10, size=(self.batch_size,)) + return data, label + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +def main(): + # initialize distributed setting + parser = colossalai.get_default_parser() + parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + args = parser.parse_args() + + # launch from torch + colossalai.launch_from_torch(config=args.config) + + # get logger + logger = get_dist_logger() + logger.info("initialized distributed environment", ranks=[0]) + + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + use_pipeline = is_using_pp() + + # create model + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) + + if use_pipeline: + pipelinable = PipelinableContext() + with pipelinable: + model = _create_vit_model(**model_kwargs) + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + else: + model = _create_vit_model(**model_kwargs) + + # count number of parameters + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 + else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") + + # create dataloaders + root = os.environ.get('DATA', '../data') + if args.synthetic: + # if we use synthetic dataset + # we train for 10 steps and eval for 5 steps per epoch + train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE) + test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE) + else: + train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + + logger.info("Engine is built", ranks=[0]) + + for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + data_iter = iter(train_dataloader) + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/large_batch_optimizer/README.md b/examples/tutorial/large_batch_optimizer/README.md new file mode 100644 index 000000000..20bddb383 --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/README.md @@ -0,0 +1,31 @@ +# Comparison of Large Batch Training Optimization + +## 🚀Quick Start +Run with synthetic data +```bash +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` + + +## Prepare Dataset + +We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`. +The dataset will be downloaded to `colossalai/examples/tutorials/data` by default. +If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. + +```bash +export DATA=/path/to/data +``` + +You can also use synthetic data for this tutorial if you don't wish to download the `CIFAR10` dataset by adding the `-s` or `--synthetic` flag to the command. + + +## Run on 2*2 device mesh + +```bash +# run with cifar10 +colossalai run --nproc_per_node 4 train.py --config config.py + +# run with synthetic dataset +colossalai run --nproc_per_node 4 train.py --config config.py -s +``` diff --git a/examples/tutorial/large_batch_optimizer/config.py b/examples/tutorial/large_batch_optimizer/config.py new file mode 100644 index 000000000..e019154e4 --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/config.py @@ -0,0 +1,36 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 512 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 3 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 512 +DEPTH = 4 +NUM_HEADS = 4 +MLP_RATIO = 2 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 + +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] diff --git a/examples/tutorial/large_batch_optimizer/requirements.txt b/examples/tutorial/large_batch_optimizer/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py new file mode 100644 index 000000000..d403c275d --- /dev/null +++ b/examples/tutorial/large_batch_optimizer/train.py @@ -0,0 +1,144 @@ +import os + +import torch +from titans.dataloader.cifar10 import build_cifar +from titans.model.vit.vit import _create_vit_model +from tqdm import tqdm + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import Lamb, Lars +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.utils import get_dataloader, is_using_pp + + +class DummyDataloader(): + + def __init__(self, length, batch_size): + self.length = length + self.batch_size = batch_size + + def generate(self): + data = torch.rand(self.batch_size, 3, 224, 224) + label = torch.randint(low=0, high=10, size=(self.batch_size,)) + return data, label + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +def main(): + # initialize distributed setting + parser = colossalai.get_default_parser() + parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + args = parser.parse_args() + + # launch from torch + colossalai.launch_from_torch(config=args.config) + + # get logger + logger = get_dist_logger() + logger.info("initialized distributed environment", ranks=[0]) + + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + use_pipeline = is_using_pp() + + # create model + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) + + if use_pipeline: + pipelinable = PipelinableContext() + with pipelinable: + model = _create_vit_model(**model_kwargs) + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + else: + model = _create_vit_model(**model_kwargs) + + # count number of parameters + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 + else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") + + # create dataloaders + root = os.environ.get('DATA', '../data/') + if args.synthetic: + train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE) + test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE) + else: + train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + + logger.info("Engine is built", ranks=[0]) + + for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + data_iter = iter(train_dataloader) + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/opt/inference/README.md b/examples/tutorial/opt/inference/README.md new file mode 100644 index 000000000..5bacac0d7 --- /dev/null +++ b/examples/tutorial/opt/inference/README.md @@ -0,0 +1,88 @@ +# Overview + +This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI. + +It supports tensor parallelism, batching and caching. + +## 🚀Quick Start +1. Run inference with OPT 125M +```bash +docker hpcaitech/tutorial:opt-inference +docker run -it --rm --gpus all --ipc host -p 7070:7070 hpcaitech/tutorial:opt-inference +``` +2. Start the http server inside the docker container with tensor parallel size 2 +```bash +python opt_fastapi.py opt-125m --tp 2 --checkpoint /data/opt-125m +``` + +# How to run + +Run OPT-125M: +```shell +python opt_fastapi.py opt-125m +``` + +It will launch a HTTP server on `0.0.0.0:7070` by default and you can customize host and port. You can open `localhost:7070/docs` in your browser to see the openapi docs. + +## Configure + +### Configure model +```shell +python opt_fastapi.py +``` +Available models: opt-125m, opt-6.7b, opt-30b, opt-175b. + +### Configure tensor parallelism +```shell +python opt_fastapi.py --tp +``` +The `` can be an integer in `[1, #GPUs]`. Default `1`. + +### Configure checkpoint +```shell +python opt_fastapi.py --checkpoint +``` +The `` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded. + +### Configure queue +```shell +python opt_fastapi.py --queue_size +``` +The `` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406). + +### Configure bathcing +```shell +python opt_fastapi.py --max_batch_size +``` +The `` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value. + +Note that the batch size is not always equal to ``, as some consecutive requests may not be batched. + +### Configure caching +```shell +python opt_fastapi.py --cache_size --cache_list_size +``` +This will cache `` unique requests. And for each unique request, it cache `` different results. A random result will be returned if the cache is hit. + +The `` can be an integer in `[0, MAXINT]`. If it's `0`, cache won't be applied. The `` can be an integer in `[1, MAXINT]`. + +### Other configurations +```shell +python opt_fastapi.py -h +``` + +# How to benchmark +```shell +cd benchmark +locust +``` + +Then open the web interface link which is on your console. + +# Pre-process pre-trained weights + +## OPT-66B +See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py). + +## OPT-175B +See [script/process-opt-175b](./script/process-opt-175b/). \ No newline at end of file diff --git a/examples/tutorial/opt/inference/batch.py b/examples/tutorial/opt/inference/batch.py new file mode 100644 index 000000000..1a0876ca8 --- /dev/null +++ b/examples/tutorial/opt/inference/batch.py @@ -0,0 +1,59 @@ +import torch +from typing import List, Deque, Tuple, Hashable, Any +from energonai import BatchManager, SubmitEntry, TaskEntry + + +class BatchManagerForGeneration(BatchManager): + def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.pad_token_id = pad_token_id + + def _left_padding(self, batch_inputs): + max_len = max(len(inputs['input_ids']) for inputs in batch_inputs) + outputs = {'input_ids': [], 'attention_mask': []} + for inputs in batch_inputs: + input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + padding_len = max_len - len(input_ids) + input_ids = [self.pad_token_id] * padding_len + input_ids + attention_mask = [0] * padding_len + attention_mask + outputs['input_ids'].append(input_ids) + outputs['attention_mask'].append(attention_mask) + for k in outputs: + outputs[k] = torch.tensor(outputs[k]) + return outputs, max_len + + @staticmethod + def _make_batch_key(entry: SubmitEntry) -> tuple: + data = entry.data + return (data['top_k'], data['top_p'], data['temperature']) + + def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: + entry = q.popleft() + uids = [entry.uid] + batch = [entry.data] + while len(batch) < self.max_batch_size: + if len(q) == 0: + break + if self._make_batch_key(entry) != self._make_batch_key(q[0]): + break + if q[0].data['max_tokens'] > entry.data['max_tokens']: + break + e = q.popleft() + batch.append(e.data) + uids.append(e.uid) + inputs, max_len = self._left_padding(batch) + trunc_lens = [] + for data in batch: + trunc_lens.append(max_len + data['max_tokens']) + inputs['top_k'] = entry.data['top_k'] + inputs['top_p'] = entry.data['top_p'] + inputs['temperature'] = entry.data['temperature'] + inputs['max_tokens'] = max_len + entry.data['max_tokens'] + return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens} + + def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]: + retval = [] + for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens): + retval.append((uid, output[:trunc_len])) + return retval diff --git a/examples/tutorial/opt/inference/benchmark/locustfile.py b/examples/tutorial/opt/inference/benchmark/locustfile.py new file mode 100644 index 000000000..4d829e5d8 --- /dev/null +++ b/examples/tutorial/opt/inference/benchmark/locustfile.py @@ -0,0 +1,15 @@ +from locust import HttpUser, task +from json import JSONDecodeError + + +class GenerationUser(HttpUser): + @task + def generate(self): + prompt = 'Question: What is the longest river on the earth? Answer:' + for i in range(4, 9): + data = {'max_tokens': 2**i, 'prompt': prompt} + with self.client.post('/generation', json=data, catch_response=True) as response: + if response.status_code in (200, 406): + response.success() + else: + response.failure('Response wrong') diff --git a/examples/tutorial/opt/inference/cache.py b/examples/tutorial/opt/inference/cache.py new file mode 100644 index 000000000..30febc44f --- /dev/null +++ b/examples/tutorial/opt/inference/cache.py @@ -0,0 +1,64 @@ +from collections import OrderedDict +from threading import Lock +from contextlib import contextmanager +from typing import List, Any, Hashable, Dict + + +class MissCacheError(Exception): + pass + + +class ListCache: + def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None: + """Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied. + When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed. + + Args: + cache_size (int): Max size for LRU cache. + list_size (int): Value list size. + fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to []. + """ + self.cache_size = cache_size + self.list_size = list_size + self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict() + self.fixed_cache: Dict[Hashable, List[Any]] = {} + for key in fixed_keys: + self.fixed_cache[key] = [] + self._lock = Lock() + + def get(self, key: Hashable) -> List[Any]: + with self.lock(): + if key in self.fixed_cache: + l = self.fixed_cache[key] + if len(l) >= self.list_size: + return l + elif key in self.cache: + self.cache.move_to_end(key) + l = self.cache[key] + if len(l) >= self.list_size: + return l + raise MissCacheError() + + def add(self, key: Hashable, value: Any) -> None: + with self.lock(): + if key in self.fixed_cache: + l = self.fixed_cache[key] + if len(l) < self.list_size and value not in l: + l.append(value) + elif key in self.cache: + self.cache.move_to_end(key) + l = self.cache[key] + if len(l) < self.list_size and value not in l: + l.append(value) + else: + if len(self.cache) >= self.cache_size: + self.cache.popitem(last=False) + self.cache[key] = [value] + + @contextmanager + def lock(self): + try: + self._lock.acquire() + yield + finally: + self._lock.release() diff --git a/examples/tutorial/opt/inference/opt_fastapi.py b/examples/tutorial/opt/inference/opt_fastapi.py new file mode 100644 index 000000000..cbfc2a22e --- /dev/null +++ b/examples/tutorial/opt/inference/opt_fastapi.py @@ -0,0 +1,123 @@ +import argparse +import logging +import random +from typing import Optional + +import uvicorn +from energonai import QueueFullError, launch_engine +from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B +from fastapi import FastAPI, HTTPException, Request +from pydantic import BaseModel, Field +from transformers import GPT2Tokenizer + +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError + + +class GenerationTaskReq(BaseModel): + max_tokens: int = Field(gt=0, le=256, example=64) + prompt: str = Field( + min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + top_k: Optional[int] = Field(default=None, gt=0, example=50) + top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) + temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) + + +app = FastAPI() + + +@app.post('/generation') +async def generate(data: GenerationTaskReq, request: Request): + logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}') + key = (data.prompt, data.max_tokens) + try: + if cache is None: + raise MissCacheError() + outputs = cache.get(key) + output = random.choice(outputs) + logger.info('Cache hit') + except MissCacheError: + inputs = tokenizer(data.prompt, truncation=True, max_length=512) + inputs['max_tokens'] = data.max_tokens + inputs['top_k'] = data.top_k + inputs['top_p'] = data.top_p + inputs['temperature'] = data.temperature + try: + uid = id(data) + engine.submit(uid, inputs) + output = await engine.wait(uid) + output = tokenizer.decode(output, skip_special_tokens=True) + if cache is not None: + cache.add(key, output) + except QueueFullError as e: + raise HTTPException(status_code=406, detail=e.args[0]) + + return {'text': output} + + +@app.on_event("shutdown") +async def shutdown(*_): + engine.shutdown() + server.should_exit = True + server.force_exit = True + await server.shutdown() + + +def get_model_fn(model_name: str): + model_map = { + 'opt-125m': opt_125M, + 'opt-6.7b': opt_6B, + 'opt-30b': opt_30B, + 'opt-175b': opt_175B + } + return model_map[model_name] + + +def print_args(args: argparse.Namespace): + print('\n==> Args:') + for k, v in args.__dict__.items(): + print(f'{k} = {v}') + + +FIXED_CACHE_KEYS = [ + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), + ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) +] + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--master_host', default='localhost') + parser.add_argument('--master_port', type=int, default=19990) + parser.add_argument('--rpc_port', type=int, default=19980) + parser.add_argument('--max_batch_size', type=int, default=8) + parser.add_argument('--pipe_size', type=int, default=1) + parser.add_argument('--queue_size', type=int, default=0) + parser.add_argument('--http_host', default='0.0.0.0') + parser.add_argument('--http_port', type=int, default=7070) + parser.add_argument('--checkpoint', default=None) + parser.add_argument('--cache_size', type=int, default=0) + parser.add_argument('--cache_list_size', type=int, default=1) + args = parser.parse_args() + print_args(args) + model_kwargs = {} + if args.checkpoint is not None: + model_kwargs['checkpoint'] = args.checkpoint + + logger = logging.getLogger(__name__) + tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + if args.cache_size > 0: + cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) + else: + cache = None + engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, + pad_token_id=tokenizer.pad_token_id), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs) + config = uvicorn.Config(app, host=args.http_host, port=args.http_port) + server = uvicorn.Server(config=config) + server.run() diff --git a/examples/tutorial/opt/inference/opt_server.py b/examples/tutorial/opt/inference/opt_server.py new file mode 100644 index 000000000..8dab82622 --- /dev/null +++ b/examples/tutorial/opt/inference/opt_server.py @@ -0,0 +1,122 @@ +import logging +import argparse +import random +from torch import Tensor +from pydantic import BaseModel, Field +from typing import Optional +from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B +from transformers import GPT2Tokenizer +from energonai import launch_engine, QueueFullError +from sanic import Sanic +from sanic.request import Request +from sanic.response import json +from sanic_ext import validate, openapi +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError + + +class GenerationTaskReq(BaseModel): + max_tokens: int = Field(gt=0, le=256, example=64) + prompt: str = Field( + min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + top_k: Optional[int] = Field(default=None, gt=0, example=50) + top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) + temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) + + +app = Sanic('opt') + + +@app.post('/generation') +@openapi.body(GenerationTaskReq) +@validate(json=GenerationTaskReq) +async def generate(request: Request, body: GenerationTaskReq): + logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}') + key = (body.prompt, body.max_tokens) + try: + if cache is None: + raise MissCacheError() + outputs = cache.get(key) + output = random.choice(outputs) + logger.info('Cache hit') + except MissCacheError: + inputs = tokenizer(body.prompt, truncation=True, max_length=512) + inputs['max_tokens'] = body.max_tokens + inputs['top_k'] = body.top_k + inputs['top_p'] = body.top_p + inputs['temperature'] = body.temperature + try: + uid = id(body) + engine.submit(uid, inputs) + output = await engine.wait(uid) + assert isinstance(output, Tensor) + output = tokenizer.decode(output, skip_special_tokens=True) + if cache is not None: + cache.add(key, output) + except QueueFullError as e: + return json({'detail': e.args[0]}, status=406) + + return json({'text': output}) + + +@app.after_server_stop +def shutdown(*_): + engine.shutdown() + + +def get_model_fn(model_name: str): + model_map = { + 'opt-125m': opt_125M, + 'opt-6.7b': opt_6B, + 'opt-30b': opt_30B, + 'opt-175b': opt_175B + } + return model_map[model_name] + + +def print_args(args: argparse.Namespace): + print('\n==> Args:') + for k, v in args.__dict__.items(): + print(f'{k} = {v}') + + +FIXED_CACHE_KEYS = [ + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), + ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) +] + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--master_host', default='localhost') + parser.add_argument('--master_port', type=int, default=19990) + parser.add_argument('--rpc_port', type=int, default=19980) + parser.add_argument('--max_batch_size', type=int, default=8) + parser.add_argument('--pipe_size', type=int, default=1) + parser.add_argument('--queue_size', type=int, default=0) + parser.add_argument('--http_host', default='0.0.0.0') + parser.add_argument('--http_port', type=int, default=7070) + parser.add_argument('--checkpoint', default=None) + parser.add_argument('--cache_size', type=int, default=0) + parser.add_argument('--cache_list_size', type=int, default=1) + args = parser.parse_args() + print_args(args) + model_kwargs = {} + if args.checkpoint is not None: + model_kwargs['checkpoint'] = args.checkpoint + + logger = logging.getLogger(__name__) + tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + if args.cache_size > 0: + cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) + else: + cache = None + engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, + pad_token_id=tokenizer.pad_token_id), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs) + app.run(args.http_host, args.http_port) diff --git a/examples/tutorial/opt/inference/requirements.txt b/examples/tutorial/opt/inference/requirements.txt new file mode 100644 index 000000000..e6e8511e3 --- /dev/null +++ b/examples/tutorial/opt/inference/requirements.txt @@ -0,0 +1,9 @@ +fastapi==0.85.1 +locust==2.11.0 +pydantic==1.10.2 +sanic==22.9.0 +sanic_ext==22.9.0 +torch>=1.10.0 +transformers==4.23.1 +uvicorn==0.19.0 +colossalai diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/README.md b/examples/tutorial/opt/inference/script/process-opt-175b/README.md new file mode 100644 index 000000000..bc3cba72d --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/README.md @@ -0,0 +1,46 @@ +# Process OPT-175B weights + +You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this. + +First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`. + +Then, `cd metaseq`. + +To consolidate checkpoints to eliminate FSDP: + +```shell +bash metaseq/scripts/reshard_mp_launch_no_slurm.sh /checkpoint_last / 8 1 +``` + +You will get 8 files in ``, and you should have the following checksums: +``` +7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt +c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt +45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt +abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt +05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt +d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt +fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt +2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt +``` + +Copy `flat-meta.json` to ``. + +Then cd to this dir, and we unflatten parameters. + +```shell +bash unflat.sh / / +``` + +Finally, you will get 8 files in `` with following checksums: +``` +6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt +58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt +69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt +002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt +6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt +93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt +5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt +f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt +``` + diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py new file mode 100644 index 000000000..a17ddd4fa --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py @@ -0,0 +1,55 @@ +import argparse +import json +import os +import re +from collections import defaultdict + +import numpy as np +import torch + + +def load_json(path: str): + with open(path) as f: + return json.load(f) + + +def parse_shape_info(flat_dir: str): + data = load_json(os.path.join(flat_dir, 'shape.json')) + flat_info = defaultdict(lambda: defaultdict(list)) + for k, shape in data.items(): + matched = re.match(r'decoder.layers.\d+', k) + if matched is None: + flat_key = 'flat_param_0' + else: + flat_key = f'{matched[0]}.flat_param_0' + flat_info[flat_key]['names'].append(k) + flat_info[flat_key]['shapes'].append(shape) + flat_info[flat_key]['numels'].append(int(np.prod(shape))) + return flat_info + + +def convert(flat_dir: str, output_dir: str, part: int): + flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt') + output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt') + flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json')) + flat_sd = torch.load(flat_path) + print(f'Loaded flat state dict from {flat_path}') + output_sd = {} + for flat_key, param_meta in flat_meta.items(): + flat_param = flat_sd['model'][flat_key] + assert sum(param_meta['numels']) == flat_param.numel( + ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' + for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])): + output_sd[name] = param.view(shape) + + torch.save(output_sd, output_path) + print(f'Saved unflat state dict to {output_path}') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('flat_dir') + parser.add_argument('output_dir') + parser.add_argument('part', type=int) + args = parser.parse_args() + convert(args.flat_dir, args.output_dir, args.part) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json new file mode 100644 index 000000000..59d285565 --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json @@ -0,0 +1 @@ +{"flat_param_0": {"names": ["decoder.embed_tokens.weight", "decoder.embed_positions.weight", "decoder.layer_norm.weight", "decoder.layer_norm.bias"], "shapes": [[6284, 12288], [2050, 12288], [12288], [12288]], "numels": [77217792, 25190400, 12288, 12288]}, "decoder.layers.0.flat_param_0": {"names": ["decoder.layers.0.self_attn.qkv_proj.weight", "decoder.layers.0.self_attn.qkv_proj.bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.self_attn_layer_norm.weight", "decoder.layers.0.self_attn_layer_norm.bias", "decoder.layers.0.fc1.weight", "decoder.layers.0.fc1.bias", "decoder.layers.0.fc2.weight", "decoder.layers.0.fc2.bias", "decoder.layers.0.final_layer_norm.weight", "decoder.layers.0.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.1.flat_param_0": {"names": ["decoder.layers.1.self_attn.qkv_proj.weight", "decoder.layers.1.self_attn.qkv_proj.bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.self_attn_layer_norm.weight", "decoder.layers.1.self_attn_layer_norm.bias", "decoder.layers.1.fc1.weight", "decoder.layers.1.fc1.bias", "decoder.layers.1.fc2.weight", "decoder.layers.1.fc2.bias", "decoder.layers.1.final_layer_norm.weight", "decoder.layers.1.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.2.flat_param_0": {"names": ["decoder.layers.2.self_attn.qkv_proj.weight", "decoder.layers.2.self_attn.qkv_proj.bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.self_attn_layer_norm.weight", "decoder.layers.2.self_attn_layer_norm.bias", "decoder.layers.2.fc1.weight", "decoder.layers.2.fc1.bias", "decoder.layers.2.fc2.weight", "decoder.layers.2.fc2.bias", "decoder.layers.2.final_layer_norm.weight", "decoder.layers.2.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.3.flat_param_0": {"names": ["decoder.layers.3.self_attn.qkv_proj.weight", "decoder.layers.3.self_attn.qkv_proj.bias", "decoder.layers.3.self_attn.out_proj.weight", "decoder.layers.3.self_attn.out_proj.bias", "decoder.layers.3.self_attn_layer_norm.weight", "decoder.layers.3.self_attn_layer_norm.bias", "decoder.layers.3.fc1.weight", "decoder.layers.3.fc1.bias", "decoder.layers.3.fc2.weight", "decoder.layers.3.fc2.bias", "decoder.layers.3.final_layer_norm.weight", "decoder.layers.3.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.4.flat_param_0": {"names": ["decoder.layers.4.self_attn.qkv_proj.weight", "decoder.layers.4.self_attn.qkv_proj.bias", "decoder.layers.4.self_attn.out_proj.weight", "decoder.layers.4.self_attn.out_proj.bias", "decoder.layers.4.self_attn_layer_norm.weight", "decoder.layers.4.self_attn_layer_norm.bias", "decoder.layers.4.fc1.weight", "decoder.layers.4.fc1.bias", "decoder.layers.4.fc2.weight", "decoder.layers.4.fc2.bias", "decoder.layers.4.final_layer_norm.weight", "decoder.layers.4.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.5.flat_param_0": {"names": ["decoder.layers.5.self_attn.qkv_proj.weight", "decoder.layers.5.self_attn.qkv_proj.bias", "decoder.layers.5.self_attn.out_proj.weight", "decoder.layers.5.self_attn.out_proj.bias", "decoder.layers.5.self_attn_layer_norm.weight", "decoder.layers.5.self_attn_layer_norm.bias", "decoder.layers.5.fc1.weight", "decoder.layers.5.fc1.bias", "decoder.layers.5.fc2.weight", "decoder.layers.5.fc2.bias", "decoder.layers.5.final_layer_norm.weight", "decoder.layers.5.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.6.flat_param_0": {"names": ["decoder.layers.6.self_attn.qkv_proj.weight", "decoder.layers.6.self_attn.qkv_proj.bias", "decoder.layers.6.self_attn.out_proj.weight", "decoder.layers.6.self_attn.out_proj.bias", "decoder.layers.6.self_attn_layer_norm.weight", "decoder.layers.6.self_attn_layer_norm.bias", "decoder.layers.6.fc1.weight", "decoder.layers.6.fc1.bias", "decoder.layers.6.fc2.weight", "decoder.layers.6.fc2.bias", "decoder.layers.6.final_layer_norm.weight", "decoder.layers.6.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.7.flat_param_0": {"names": ["decoder.layers.7.self_attn.qkv_proj.weight", "decoder.layers.7.self_attn.qkv_proj.bias", "decoder.layers.7.self_attn.out_proj.weight", "decoder.layers.7.self_attn.out_proj.bias", "decoder.layers.7.self_attn_layer_norm.weight", "decoder.layers.7.self_attn_layer_norm.bias", "decoder.layers.7.fc1.weight", "decoder.layers.7.fc1.bias", "decoder.layers.7.fc2.weight", "decoder.layers.7.fc2.bias", "decoder.layers.7.final_layer_norm.weight", "decoder.layers.7.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.8.flat_param_0": {"names": ["decoder.layers.8.self_attn.qkv_proj.weight", "decoder.layers.8.self_attn.qkv_proj.bias", "decoder.layers.8.self_attn.out_proj.weight", "decoder.layers.8.self_attn.out_proj.bias", "decoder.layers.8.self_attn_layer_norm.weight", "decoder.layers.8.self_attn_layer_norm.bias", "decoder.layers.8.fc1.weight", "decoder.layers.8.fc1.bias", "decoder.layers.8.fc2.weight", "decoder.layers.8.fc2.bias", "decoder.layers.8.final_layer_norm.weight", "decoder.layers.8.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.9.flat_param_0": {"names": ["decoder.layers.9.self_attn.qkv_proj.weight", "decoder.layers.9.self_attn.qkv_proj.bias", "decoder.layers.9.self_attn.out_proj.weight", "decoder.layers.9.self_attn.out_proj.bias", "decoder.layers.9.self_attn_layer_norm.weight", "decoder.layers.9.self_attn_layer_norm.bias", "decoder.layers.9.fc1.weight", "decoder.layers.9.fc1.bias", "decoder.layers.9.fc2.weight", "decoder.layers.9.fc2.bias", "decoder.layers.9.final_layer_norm.weight", "decoder.layers.9.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.10.flat_param_0": {"names": ["decoder.layers.10.self_attn.qkv_proj.weight", "decoder.layers.10.self_attn.qkv_proj.bias", "decoder.layers.10.self_attn.out_proj.weight", "decoder.layers.10.self_attn.out_proj.bias", "decoder.layers.10.self_attn_layer_norm.weight", "decoder.layers.10.self_attn_layer_norm.bias", "decoder.layers.10.fc1.weight", "decoder.layers.10.fc1.bias", "decoder.layers.10.fc2.weight", "decoder.layers.10.fc2.bias", "decoder.layers.10.final_layer_norm.weight", "decoder.layers.10.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.11.flat_param_0": {"names": ["decoder.layers.11.self_attn.qkv_proj.weight", "decoder.layers.11.self_attn.qkv_proj.bias", "decoder.layers.11.self_attn.out_proj.weight", "decoder.layers.11.self_attn.out_proj.bias", "decoder.layers.11.self_attn_layer_norm.weight", "decoder.layers.11.self_attn_layer_norm.bias", "decoder.layers.11.fc1.weight", "decoder.layers.11.fc1.bias", "decoder.layers.11.fc2.weight", "decoder.layers.11.fc2.bias", "decoder.layers.11.final_layer_norm.weight", "decoder.layers.11.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.12.flat_param_0": {"names": ["decoder.layers.12.self_attn.qkv_proj.weight", "decoder.layers.12.self_attn.qkv_proj.bias", "decoder.layers.12.self_attn.out_proj.weight", "decoder.layers.12.self_attn.out_proj.bias", "decoder.layers.12.self_attn_layer_norm.weight", "decoder.layers.12.self_attn_layer_norm.bias", "decoder.layers.12.fc1.weight", "decoder.layers.12.fc1.bias", "decoder.layers.12.fc2.weight", "decoder.layers.12.fc2.bias", "decoder.layers.12.final_layer_norm.weight", "decoder.layers.12.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.13.flat_param_0": {"names": ["decoder.layers.13.self_attn.qkv_proj.weight", "decoder.layers.13.self_attn.qkv_proj.bias", "decoder.layers.13.self_attn.out_proj.weight", "decoder.layers.13.self_attn.out_proj.bias", "decoder.layers.13.self_attn_layer_norm.weight", "decoder.layers.13.self_attn_layer_norm.bias", "decoder.layers.13.fc1.weight", "decoder.layers.13.fc1.bias", "decoder.layers.13.fc2.weight", "decoder.layers.13.fc2.bias", "decoder.layers.13.final_layer_norm.weight", "decoder.layers.13.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.14.flat_param_0": {"names": ["decoder.layers.14.self_attn.qkv_proj.weight", "decoder.layers.14.self_attn.qkv_proj.bias", "decoder.layers.14.self_attn.out_proj.weight", "decoder.layers.14.self_attn.out_proj.bias", "decoder.layers.14.self_attn_layer_norm.weight", "decoder.layers.14.self_attn_layer_norm.bias", "decoder.layers.14.fc1.weight", "decoder.layers.14.fc1.bias", "decoder.layers.14.fc2.weight", "decoder.layers.14.fc2.bias", "decoder.layers.14.final_layer_norm.weight", "decoder.layers.14.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.15.flat_param_0": {"names": ["decoder.layers.15.self_attn.qkv_proj.weight", "decoder.layers.15.self_attn.qkv_proj.bias", "decoder.layers.15.self_attn.out_proj.weight", "decoder.layers.15.self_attn.out_proj.bias", "decoder.layers.15.self_attn_layer_norm.weight", "decoder.layers.15.self_attn_layer_norm.bias", "decoder.layers.15.fc1.weight", "decoder.layers.15.fc1.bias", "decoder.layers.15.fc2.weight", "decoder.layers.15.fc2.bias", "decoder.layers.15.final_layer_norm.weight", "decoder.layers.15.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.16.flat_param_0": {"names": ["decoder.layers.16.self_attn.qkv_proj.weight", "decoder.layers.16.self_attn.qkv_proj.bias", "decoder.layers.16.self_attn.out_proj.weight", "decoder.layers.16.self_attn.out_proj.bias", "decoder.layers.16.self_attn_layer_norm.weight", "decoder.layers.16.self_attn_layer_norm.bias", "decoder.layers.16.fc1.weight", "decoder.layers.16.fc1.bias", "decoder.layers.16.fc2.weight", "decoder.layers.16.fc2.bias", "decoder.layers.16.final_layer_norm.weight", "decoder.layers.16.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.17.flat_param_0": {"names": ["decoder.layers.17.self_attn.qkv_proj.weight", "decoder.layers.17.self_attn.qkv_proj.bias", "decoder.layers.17.self_attn.out_proj.weight", "decoder.layers.17.self_attn.out_proj.bias", "decoder.layers.17.self_attn_layer_norm.weight", "decoder.layers.17.self_attn_layer_norm.bias", "decoder.layers.17.fc1.weight", "decoder.layers.17.fc1.bias", "decoder.layers.17.fc2.weight", "decoder.layers.17.fc2.bias", "decoder.layers.17.final_layer_norm.weight", "decoder.layers.17.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.18.flat_param_0": {"names": ["decoder.layers.18.self_attn.qkv_proj.weight", "decoder.layers.18.self_attn.qkv_proj.bias", "decoder.layers.18.self_attn.out_proj.weight", "decoder.layers.18.self_attn.out_proj.bias", "decoder.layers.18.self_attn_layer_norm.weight", "decoder.layers.18.self_attn_layer_norm.bias", "decoder.layers.18.fc1.weight", "decoder.layers.18.fc1.bias", "decoder.layers.18.fc2.weight", "decoder.layers.18.fc2.bias", "decoder.layers.18.final_layer_norm.weight", "decoder.layers.18.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.19.flat_param_0": {"names": ["decoder.layers.19.self_attn.qkv_proj.weight", "decoder.layers.19.self_attn.qkv_proj.bias", "decoder.layers.19.self_attn.out_proj.weight", "decoder.layers.19.self_attn.out_proj.bias", "decoder.layers.19.self_attn_layer_norm.weight", "decoder.layers.19.self_attn_layer_norm.bias", "decoder.layers.19.fc1.weight", "decoder.layers.19.fc1.bias", "decoder.layers.19.fc2.weight", "decoder.layers.19.fc2.bias", "decoder.layers.19.final_layer_norm.weight", "decoder.layers.19.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.20.flat_param_0": {"names": ["decoder.layers.20.self_attn.qkv_proj.weight", "decoder.layers.20.self_attn.qkv_proj.bias", "decoder.layers.20.self_attn.out_proj.weight", "decoder.layers.20.self_attn.out_proj.bias", "decoder.layers.20.self_attn_layer_norm.weight", "decoder.layers.20.self_attn_layer_norm.bias", "decoder.layers.20.fc1.weight", "decoder.layers.20.fc1.bias", "decoder.layers.20.fc2.weight", "decoder.layers.20.fc2.bias", "decoder.layers.20.final_layer_norm.weight", "decoder.layers.20.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.21.flat_param_0": {"names": ["decoder.layers.21.self_attn.qkv_proj.weight", "decoder.layers.21.self_attn.qkv_proj.bias", "decoder.layers.21.self_attn.out_proj.weight", "decoder.layers.21.self_attn.out_proj.bias", "decoder.layers.21.self_attn_layer_norm.weight", "decoder.layers.21.self_attn_layer_norm.bias", "decoder.layers.21.fc1.weight", "decoder.layers.21.fc1.bias", "decoder.layers.21.fc2.weight", "decoder.layers.21.fc2.bias", "decoder.layers.21.final_layer_norm.weight", "decoder.layers.21.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.22.flat_param_0": {"names": ["decoder.layers.22.self_attn.qkv_proj.weight", "decoder.layers.22.self_attn.qkv_proj.bias", "decoder.layers.22.self_attn.out_proj.weight", "decoder.layers.22.self_attn.out_proj.bias", "decoder.layers.22.self_attn_layer_norm.weight", "decoder.layers.22.self_attn_layer_norm.bias", "decoder.layers.22.fc1.weight", "decoder.layers.22.fc1.bias", "decoder.layers.22.fc2.weight", "decoder.layers.22.fc2.bias", "decoder.layers.22.final_layer_norm.weight", "decoder.layers.22.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.23.flat_param_0": {"names": ["decoder.layers.23.self_attn.qkv_proj.weight", "decoder.layers.23.self_attn.qkv_proj.bias", "decoder.layers.23.self_attn.out_proj.weight", "decoder.layers.23.self_attn.out_proj.bias", "decoder.layers.23.self_attn_layer_norm.weight", "decoder.layers.23.self_attn_layer_norm.bias", "decoder.layers.23.fc1.weight", "decoder.layers.23.fc1.bias", "decoder.layers.23.fc2.weight", "decoder.layers.23.fc2.bias", "decoder.layers.23.final_layer_norm.weight", "decoder.layers.23.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.24.flat_param_0": {"names": ["decoder.layers.24.self_attn.qkv_proj.weight", "decoder.layers.24.self_attn.qkv_proj.bias", "decoder.layers.24.self_attn.out_proj.weight", "decoder.layers.24.self_attn.out_proj.bias", "decoder.layers.24.self_attn_layer_norm.weight", "decoder.layers.24.self_attn_layer_norm.bias", "decoder.layers.24.fc1.weight", "decoder.layers.24.fc1.bias", "decoder.layers.24.fc2.weight", "decoder.layers.24.fc2.bias", "decoder.layers.24.final_layer_norm.weight", "decoder.layers.24.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.25.flat_param_0": {"names": ["decoder.layers.25.self_attn.qkv_proj.weight", "decoder.layers.25.self_attn.qkv_proj.bias", "decoder.layers.25.self_attn.out_proj.weight", "decoder.layers.25.self_attn.out_proj.bias", "decoder.layers.25.self_attn_layer_norm.weight", "decoder.layers.25.self_attn_layer_norm.bias", "decoder.layers.25.fc1.weight", "decoder.layers.25.fc1.bias", "decoder.layers.25.fc2.weight", "decoder.layers.25.fc2.bias", "decoder.layers.25.final_layer_norm.weight", "decoder.layers.25.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.26.flat_param_0": {"names": ["decoder.layers.26.self_attn.qkv_proj.weight", "decoder.layers.26.self_attn.qkv_proj.bias", "decoder.layers.26.self_attn.out_proj.weight", "decoder.layers.26.self_attn.out_proj.bias", "decoder.layers.26.self_attn_layer_norm.weight", "decoder.layers.26.self_attn_layer_norm.bias", "decoder.layers.26.fc1.weight", "decoder.layers.26.fc1.bias", "decoder.layers.26.fc2.weight", "decoder.layers.26.fc2.bias", "decoder.layers.26.final_layer_norm.weight", "decoder.layers.26.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.27.flat_param_0": {"names": ["decoder.layers.27.self_attn.qkv_proj.weight", "decoder.layers.27.self_attn.qkv_proj.bias", "decoder.layers.27.self_attn.out_proj.weight", "decoder.layers.27.self_attn.out_proj.bias", "decoder.layers.27.self_attn_layer_norm.weight", "decoder.layers.27.self_attn_layer_norm.bias", "decoder.layers.27.fc1.weight", "decoder.layers.27.fc1.bias", "decoder.layers.27.fc2.weight", "decoder.layers.27.fc2.bias", "decoder.layers.27.final_layer_norm.weight", "decoder.layers.27.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.28.flat_param_0": {"names": ["decoder.layers.28.self_attn.qkv_proj.weight", "decoder.layers.28.self_attn.qkv_proj.bias", "decoder.layers.28.self_attn.out_proj.weight", "decoder.layers.28.self_attn.out_proj.bias", "decoder.layers.28.self_attn_layer_norm.weight", "decoder.layers.28.self_attn_layer_norm.bias", "decoder.layers.28.fc1.weight", "decoder.layers.28.fc1.bias", "decoder.layers.28.fc2.weight", "decoder.layers.28.fc2.bias", "decoder.layers.28.final_layer_norm.weight", "decoder.layers.28.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.29.flat_param_0": {"names": ["decoder.layers.29.self_attn.qkv_proj.weight", "decoder.layers.29.self_attn.qkv_proj.bias", "decoder.layers.29.self_attn.out_proj.weight", "decoder.layers.29.self_attn.out_proj.bias", "decoder.layers.29.self_attn_layer_norm.weight", "decoder.layers.29.self_attn_layer_norm.bias", "decoder.layers.29.fc1.weight", "decoder.layers.29.fc1.bias", "decoder.layers.29.fc2.weight", "decoder.layers.29.fc2.bias", "decoder.layers.29.final_layer_norm.weight", "decoder.layers.29.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.30.flat_param_0": {"names": ["decoder.layers.30.self_attn.qkv_proj.weight", "decoder.layers.30.self_attn.qkv_proj.bias", "decoder.layers.30.self_attn.out_proj.weight", "decoder.layers.30.self_attn.out_proj.bias", "decoder.layers.30.self_attn_layer_norm.weight", "decoder.layers.30.self_attn_layer_norm.bias", "decoder.layers.30.fc1.weight", "decoder.layers.30.fc1.bias", "decoder.layers.30.fc2.weight", "decoder.layers.30.fc2.bias", "decoder.layers.30.final_layer_norm.weight", "decoder.layers.30.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.31.flat_param_0": {"names": ["decoder.layers.31.self_attn.qkv_proj.weight", "decoder.layers.31.self_attn.qkv_proj.bias", "decoder.layers.31.self_attn.out_proj.weight", "decoder.layers.31.self_attn.out_proj.bias", "decoder.layers.31.self_attn_layer_norm.weight", "decoder.layers.31.self_attn_layer_norm.bias", "decoder.layers.31.fc1.weight", "decoder.layers.31.fc1.bias", "decoder.layers.31.fc2.weight", "decoder.layers.31.fc2.bias", "decoder.layers.31.final_layer_norm.weight", "decoder.layers.31.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.32.flat_param_0": {"names": ["decoder.layers.32.self_attn.qkv_proj.weight", "decoder.layers.32.self_attn.qkv_proj.bias", "decoder.layers.32.self_attn.out_proj.weight", "decoder.layers.32.self_attn.out_proj.bias", "decoder.layers.32.self_attn_layer_norm.weight", "decoder.layers.32.self_attn_layer_norm.bias", "decoder.layers.32.fc1.weight", "decoder.layers.32.fc1.bias", "decoder.layers.32.fc2.weight", "decoder.layers.32.fc2.bias", "decoder.layers.32.final_layer_norm.weight", "decoder.layers.32.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.33.flat_param_0": {"names": ["decoder.layers.33.self_attn.qkv_proj.weight", "decoder.layers.33.self_attn.qkv_proj.bias", "decoder.layers.33.self_attn.out_proj.weight", "decoder.layers.33.self_attn.out_proj.bias", "decoder.layers.33.self_attn_layer_norm.weight", "decoder.layers.33.self_attn_layer_norm.bias", "decoder.layers.33.fc1.weight", "decoder.layers.33.fc1.bias", "decoder.layers.33.fc2.weight", "decoder.layers.33.fc2.bias", "decoder.layers.33.final_layer_norm.weight", "decoder.layers.33.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.34.flat_param_0": {"names": ["decoder.layers.34.self_attn.qkv_proj.weight", "decoder.layers.34.self_attn.qkv_proj.bias", "decoder.layers.34.self_attn.out_proj.weight", "decoder.layers.34.self_attn.out_proj.bias", "decoder.layers.34.self_attn_layer_norm.weight", "decoder.layers.34.self_attn_layer_norm.bias", "decoder.layers.34.fc1.weight", "decoder.layers.34.fc1.bias", "decoder.layers.34.fc2.weight", "decoder.layers.34.fc2.bias", "decoder.layers.34.final_layer_norm.weight", "decoder.layers.34.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.35.flat_param_0": {"names": ["decoder.layers.35.self_attn.qkv_proj.weight", "decoder.layers.35.self_attn.qkv_proj.bias", "decoder.layers.35.self_attn.out_proj.weight", "decoder.layers.35.self_attn.out_proj.bias", "decoder.layers.35.self_attn_layer_norm.weight", "decoder.layers.35.self_attn_layer_norm.bias", "decoder.layers.35.fc1.weight", "decoder.layers.35.fc1.bias", "decoder.layers.35.fc2.weight", "decoder.layers.35.fc2.bias", "decoder.layers.35.final_layer_norm.weight", "decoder.layers.35.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.36.flat_param_0": {"names": ["decoder.layers.36.self_attn.qkv_proj.weight", "decoder.layers.36.self_attn.qkv_proj.bias", "decoder.layers.36.self_attn.out_proj.weight", "decoder.layers.36.self_attn.out_proj.bias", "decoder.layers.36.self_attn_layer_norm.weight", "decoder.layers.36.self_attn_layer_norm.bias", "decoder.layers.36.fc1.weight", "decoder.layers.36.fc1.bias", "decoder.layers.36.fc2.weight", "decoder.layers.36.fc2.bias", "decoder.layers.36.final_layer_norm.weight", "decoder.layers.36.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.37.flat_param_0": {"names": ["decoder.layers.37.self_attn.qkv_proj.weight", "decoder.layers.37.self_attn.qkv_proj.bias", "decoder.layers.37.self_attn.out_proj.weight", "decoder.layers.37.self_attn.out_proj.bias", "decoder.layers.37.self_attn_layer_norm.weight", "decoder.layers.37.self_attn_layer_norm.bias", "decoder.layers.37.fc1.weight", "decoder.layers.37.fc1.bias", "decoder.layers.37.fc2.weight", "decoder.layers.37.fc2.bias", "decoder.layers.37.final_layer_norm.weight", "decoder.layers.37.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.38.flat_param_0": {"names": ["decoder.layers.38.self_attn.qkv_proj.weight", "decoder.layers.38.self_attn.qkv_proj.bias", "decoder.layers.38.self_attn.out_proj.weight", "decoder.layers.38.self_attn.out_proj.bias", "decoder.layers.38.self_attn_layer_norm.weight", "decoder.layers.38.self_attn_layer_norm.bias", "decoder.layers.38.fc1.weight", "decoder.layers.38.fc1.bias", "decoder.layers.38.fc2.weight", "decoder.layers.38.fc2.bias", "decoder.layers.38.final_layer_norm.weight", "decoder.layers.38.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.39.flat_param_0": {"names": ["decoder.layers.39.self_attn.qkv_proj.weight", "decoder.layers.39.self_attn.qkv_proj.bias", "decoder.layers.39.self_attn.out_proj.weight", "decoder.layers.39.self_attn.out_proj.bias", "decoder.layers.39.self_attn_layer_norm.weight", "decoder.layers.39.self_attn_layer_norm.bias", "decoder.layers.39.fc1.weight", "decoder.layers.39.fc1.bias", "decoder.layers.39.fc2.weight", "decoder.layers.39.fc2.bias", "decoder.layers.39.final_layer_norm.weight", "decoder.layers.39.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.40.flat_param_0": {"names": ["decoder.layers.40.self_attn.qkv_proj.weight", "decoder.layers.40.self_attn.qkv_proj.bias", "decoder.layers.40.self_attn.out_proj.weight", "decoder.layers.40.self_attn.out_proj.bias", "decoder.layers.40.self_attn_layer_norm.weight", "decoder.layers.40.self_attn_layer_norm.bias", "decoder.layers.40.fc1.weight", "decoder.layers.40.fc1.bias", "decoder.layers.40.fc2.weight", "decoder.layers.40.fc2.bias", "decoder.layers.40.final_layer_norm.weight", "decoder.layers.40.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.41.flat_param_0": {"names": ["decoder.layers.41.self_attn.qkv_proj.weight", "decoder.layers.41.self_attn.qkv_proj.bias", "decoder.layers.41.self_attn.out_proj.weight", "decoder.layers.41.self_attn.out_proj.bias", "decoder.layers.41.self_attn_layer_norm.weight", "decoder.layers.41.self_attn_layer_norm.bias", "decoder.layers.41.fc1.weight", "decoder.layers.41.fc1.bias", "decoder.layers.41.fc2.weight", "decoder.layers.41.fc2.bias", "decoder.layers.41.final_layer_norm.weight", "decoder.layers.41.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.42.flat_param_0": {"names": ["decoder.layers.42.self_attn.qkv_proj.weight", "decoder.layers.42.self_attn.qkv_proj.bias", "decoder.layers.42.self_attn.out_proj.weight", "decoder.layers.42.self_attn.out_proj.bias", "decoder.layers.42.self_attn_layer_norm.weight", "decoder.layers.42.self_attn_layer_norm.bias", "decoder.layers.42.fc1.weight", "decoder.layers.42.fc1.bias", "decoder.layers.42.fc2.weight", "decoder.layers.42.fc2.bias", "decoder.layers.42.final_layer_norm.weight", "decoder.layers.42.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.43.flat_param_0": {"names": ["decoder.layers.43.self_attn.qkv_proj.weight", "decoder.layers.43.self_attn.qkv_proj.bias", "decoder.layers.43.self_attn.out_proj.weight", "decoder.layers.43.self_attn.out_proj.bias", "decoder.layers.43.self_attn_layer_norm.weight", "decoder.layers.43.self_attn_layer_norm.bias", "decoder.layers.43.fc1.weight", "decoder.layers.43.fc1.bias", "decoder.layers.43.fc2.weight", "decoder.layers.43.fc2.bias", "decoder.layers.43.final_layer_norm.weight", "decoder.layers.43.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.44.flat_param_0": {"names": ["decoder.layers.44.self_attn.qkv_proj.weight", "decoder.layers.44.self_attn.qkv_proj.bias", "decoder.layers.44.self_attn.out_proj.weight", "decoder.layers.44.self_attn.out_proj.bias", "decoder.layers.44.self_attn_layer_norm.weight", "decoder.layers.44.self_attn_layer_norm.bias", "decoder.layers.44.fc1.weight", "decoder.layers.44.fc1.bias", "decoder.layers.44.fc2.weight", "decoder.layers.44.fc2.bias", "decoder.layers.44.final_layer_norm.weight", "decoder.layers.44.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.45.flat_param_0": {"names": ["decoder.layers.45.self_attn.qkv_proj.weight", "decoder.layers.45.self_attn.qkv_proj.bias", "decoder.layers.45.self_attn.out_proj.weight", "decoder.layers.45.self_attn.out_proj.bias", "decoder.layers.45.self_attn_layer_norm.weight", "decoder.layers.45.self_attn_layer_norm.bias", "decoder.layers.45.fc1.weight", "decoder.layers.45.fc1.bias", "decoder.layers.45.fc2.weight", "decoder.layers.45.fc2.bias", "decoder.layers.45.final_layer_norm.weight", "decoder.layers.45.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.46.flat_param_0": {"names": ["decoder.layers.46.self_attn.qkv_proj.weight", "decoder.layers.46.self_attn.qkv_proj.bias", "decoder.layers.46.self_attn.out_proj.weight", "decoder.layers.46.self_attn.out_proj.bias", "decoder.layers.46.self_attn_layer_norm.weight", "decoder.layers.46.self_attn_layer_norm.bias", "decoder.layers.46.fc1.weight", "decoder.layers.46.fc1.bias", "decoder.layers.46.fc2.weight", "decoder.layers.46.fc2.bias", "decoder.layers.46.final_layer_norm.weight", "decoder.layers.46.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.47.flat_param_0": {"names": ["decoder.layers.47.self_attn.qkv_proj.weight", "decoder.layers.47.self_attn.qkv_proj.bias", "decoder.layers.47.self_attn.out_proj.weight", "decoder.layers.47.self_attn.out_proj.bias", "decoder.layers.47.self_attn_layer_norm.weight", "decoder.layers.47.self_attn_layer_norm.bias", "decoder.layers.47.fc1.weight", "decoder.layers.47.fc1.bias", "decoder.layers.47.fc2.weight", "decoder.layers.47.fc2.bias", "decoder.layers.47.final_layer_norm.weight", "decoder.layers.47.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.48.flat_param_0": {"names": ["decoder.layers.48.self_attn.qkv_proj.weight", "decoder.layers.48.self_attn.qkv_proj.bias", "decoder.layers.48.self_attn.out_proj.weight", "decoder.layers.48.self_attn.out_proj.bias", "decoder.layers.48.self_attn_layer_norm.weight", "decoder.layers.48.self_attn_layer_norm.bias", "decoder.layers.48.fc1.weight", "decoder.layers.48.fc1.bias", "decoder.layers.48.fc2.weight", "decoder.layers.48.fc2.bias", "decoder.layers.48.final_layer_norm.weight", "decoder.layers.48.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.49.flat_param_0": {"names": ["decoder.layers.49.self_attn.qkv_proj.weight", "decoder.layers.49.self_attn.qkv_proj.bias", "decoder.layers.49.self_attn.out_proj.weight", "decoder.layers.49.self_attn.out_proj.bias", "decoder.layers.49.self_attn_layer_norm.weight", "decoder.layers.49.self_attn_layer_norm.bias", "decoder.layers.49.fc1.weight", "decoder.layers.49.fc1.bias", "decoder.layers.49.fc2.weight", "decoder.layers.49.fc2.bias", "decoder.layers.49.final_layer_norm.weight", "decoder.layers.49.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.50.flat_param_0": {"names": ["decoder.layers.50.self_attn.qkv_proj.weight", "decoder.layers.50.self_attn.qkv_proj.bias", "decoder.layers.50.self_attn.out_proj.weight", "decoder.layers.50.self_attn.out_proj.bias", "decoder.layers.50.self_attn_layer_norm.weight", "decoder.layers.50.self_attn_layer_norm.bias", "decoder.layers.50.fc1.weight", "decoder.layers.50.fc1.bias", "decoder.layers.50.fc2.weight", "decoder.layers.50.fc2.bias", "decoder.layers.50.final_layer_norm.weight", "decoder.layers.50.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.51.flat_param_0": {"names": ["decoder.layers.51.self_attn.qkv_proj.weight", "decoder.layers.51.self_attn.qkv_proj.bias", "decoder.layers.51.self_attn.out_proj.weight", "decoder.layers.51.self_attn.out_proj.bias", "decoder.layers.51.self_attn_layer_norm.weight", "decoder.layers.51.self_attn_layer_norm.bias", "decoder.layers.51.fc1.weight", "decoder.layers.51.fc1.bias", "decoder.layers.51.fc2.weight", "decoder.layers.51.fc2.bias", "decoder.layers.51.final_layer_norm.weight", "decoder.layers.51.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.52.flat_param_0": {"names": ["decoder.layers.52.self_attn.qkv_proj.weight", "decoder.layers.52.self_attn.qkv_proj.bias", "decoder.layers.52.self_attn.out_proj.weight", "decoder.layers.52.self_attn.out_proj.bias", "decoder.layers.52.self_attn_layer_norm.weight", "decoder.layers.52.self_attn_layer_norm.bias", "decoder.layers.52.fc1.weight", "decoder.layers.52.fc1.bias", "decoder.layers.52.fc2.weight", "decoder.layers.52.fc2.bias", "decoder.layers.52.final_layer_norm.weight", "decoder.layers.52.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.53.flat_param_0": {"names": ["decoder.layers.53.self_attn.qkv_proj.weight", "decoder.layers.53.self_attn.qkv_proj.bias", "decoder.layers.53.self_attn.out_proj.weight", "decoder.layers.53.self_attn.out_proj.bias", "decoder.layers.53.self_attn_layer_norm.weight", "decoder.layers.53.self_attn_layer_norm.bias", "decoder.layers.53.fc1.weight", "decoder.layers.53.fc1.bias", "decoder.layers.53.fc2.weight", "decoder.layers.53.fc2.bias", "decoder.layers.53.final_layer_norm.weight", "decoder.layers.53.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.54.flat_param_0": {"names": ["decoder.layers.54.self_attn.qkv_proj.weight", "decoder.layers.54.self_attn.qkv_proj.bias", "decoder.layers.54.self_attn.out_proj.weight", "decoder.layers.54.self_attn.out_proj.bias", "decoder.layers.54.self_attn_layer_norm.weight", "decoder.layers.54.self_attn_layer_norm.bias", "decoder.layers.54.fc1.weight", "decoder.layers.54.fc1.bias", "decoder.layers.54.fc2.weight", "decoder.layers.54.fc2.bias", "decoder.layers.54.final_layer_norm.weight", "decoder.layers.54.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.55.flat_param_0": {"names": ["decoder.layers.55.self_attn.qkv_proj.weight", "decoder.layers.55.self_attn.qkv_proj.bias", "decoder.layers.55.self_attn.out_proj.weight", "decoder.layers.55.self_attn.out_proj.bias", "decoder.layers.55.self_attn_layer_norm.weight", "decoder.layers.55.self_attn_layer_norm.bias", "decoder.layers.55.fc1.weight", "decoder.layers.55.fc1.bias", "decoder.layers.55.fc2.weight", "decoder.layers.55.fc2.bias", "decoder.layers.55.final_layer_norm.weight", "decoder.layers.55.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.56.flat_param_0": {"names": ["decoder.layers.56.self_attn.qkv_proj.weight", "decoder.layers.56.self_attn.qkv_proj.bias", "decoder.layers.56.self_attn.out_proj.weight", "decoder.layers.56.self_attn.out_proj.bias", "decoder.layers.56.self_attn_layer_norm.weight", "decoder.layers.56.self_attn_layer_norm.bias", "decoder.layers.56.fc1.weight", "decoder.layers.56.fc1.bias", "decoder.layers.56.fc2.weight", "decoder.layers.56.fc2.bias", "decoder.layers.56.final_layer_norm.weight", "decoder.layers.56.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.57.flat_param_0": {"names": ["decoder.layers.57.self_attn.qkv_proj.weight", "decoder.layers.57.self_attn.qkv_proj.bias", "decoder.layers.57.self_attn.out_proj.weight", "decoder.layers.57.self_attn.out_proj.bias", "decoder.layers.57.self_attn_layer_norm.weight", "decoder.layers.57.self_attn_layer_norm.bias", "decoder.layers.57.fc1.weight", "decoder.layers.57.fc1.bias", "decoder.layers.57.fc2.weight", "decoder.layers.57.fc2.bias", "decoder.layers.57.final_layer_norm.weight", "decoder.layers.57.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.58.flat_param_0": {"names": ["decoder.layers.58.self_attn.qkv_proj.weight", "decoder.layers.58.self_attn.qkv_proj.bias", "decoder.layers.58.self_attn.out_proj.weight", "decoder.layers.58.self_attn.out_proj.bias", "decoder.layers.58.self_attn_layer_norm.weight", "decoder.layers.58.self_attn_layer_norm.bias", "decoder.layers.58.fc1.weight", "decoder.layers.58.fc1.bias", "decoder.layers.58.fc2.weight", "decoder.layers.58.fc2.bias", "decoder.layers.58.final_layer_norm.weight", "decoder.layers.58.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.59.flat_param_0": {"names": ["decoder.layers.59.self_attn.qkv_proj.weight", "decoder.layers.59.self_attn.qkv_proj.bias", "decoder.layers.59.self_attn.out_proj.weight", "decoder.layers.59.self_attn.out_proj.bias", "decoder.layers.59.self_attn_layer_norm.weight", "decoder.layers.59.self_attn_layer_norm.bias", "decoder.layers.59.fc1.weight", "decoder.layers.59.fc1.bias", "decoder.layers.59.fc2.weight", "decoder.layers.59.fc2.bias", "decoder.layers.59.final_layer_norm.weight", "decoder.layers.59.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.60.flat_param_0": {"names": ["decoder.layers.60.self_attn.qkv_proj.weight", "decoder.layers.60.self_attn.qkv_proj.bias", "decoder.layers.60.self_attn.out_proj.weight", "decoder.layers.60.self_attn.out_proj.bias", "decoder.layers.60.self_attn_layer_norm.weight", "decoder.layers.60.self_attn_layer_norm.bias", "decoder.layers.60.fc1.weight", "decoder.layers.60.fc1.bias", "decoder.layers.60.fc2.weight", "decoder.layers.60.fc2.bias", "decoder.layers.60.final_layer_norm.weight", "decoder.layers.60.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.61.flat_param_0": {"names": ["decoder.layers.61.self_attn.qkv_proj.weight", "decoder.layers.61.self_attn.qkv_proj.bias", "decoder.layers.61.self_attn.out_proj.weight", "decoder.layers.61.self_attn.out_proj.bias", "decoder.layers.61.self_attn_layer_norm.weight", "decoder.layers.61.self_attn_layer_norm.bias", "decoder.layers.61.fc1.weight", "decoder.layers.61.fc1.bias", "decoder.layers.61.fc2.weight", "decoder.layers.61.fc2.bias", "decoder.layers.61.final_layer_norm.weight", "decoder.layers.61.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.62.flat_param_0": {"names": ["decoder.layers.62.self_attn.qkv_proj.weight", "decoder.layers.62.self_attn.qkv_proj.bias", "decoder.layers.62.self_attn.out_proj.weight", "decoder.layers.62.self_attn.out_proj.bias", "decoder.layers.62.self_attn_layer_norm.weight", "decoder.layers.62.self_attn_layer_norm.bias", "decoder.layers.62.fc1.weight", "decoder.layers.62.fc1.bias", "decoder.layers.62.fc2.weight", "decoder.layers.62.fc2.bias", "decoder.layers.62.final_layer_norm.weight", "decoder.layers.62.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.63.flat_param_0": {"names": ["decoder.layers.63.self_attn.qkv_proj.weight", "decoder.layers.63.self_attn.qkv_proj.bias", "decoder.layers.63.self_attn.out_proj.weight", "decoder.layers.63.self_attn.out_proj.bias", "decoder.layers.63.self_attn_layer_norm.weight", "decoder.layers.63.self_attn_layer_norm.bias", "decoder.layers.63.fc1.weight", "decoder.layers.63.fc1.bias", "decoder.layers.63.fc2.weight", "decoder.layers.63.fc2.bias", "decoder.layers.63.final_layer_norm.weight", "decoder.layers.63.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.64.flat_param_0": {"names": ["decoder.layers.64.self_attn.qkv_proj.weight", "decoder.layers.64.self_attn.qkv_proj.bias", "decoder.layers.64.self_attn.out_proj.weight", "decoder.layers.64.self_attn.out_proj.bias", "decoder.layers.64.self_attn_layer_norm.weight", "decoder.layers.64.self_attn_layer_norm.bias", "decoder.layers.64.fc1.weight", "decoder.layers.64.fc1.bias", "decoder.layers.64.fc2.weight", "decoder.layers.64.fc2.bias", "decoder.layers.64.final_layer_norm.weight", "decoder.layers.64.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.65.flat_param_0": {"names": ["decoder.layers.65.self_attn.qkv_proj.weight", "decoder.layers.65.self_attn.qkv_proj.bias", "decoder.layers.65.self_attn.out_proj.weight", "decoder.layers.65.self_attn.out_proj.bias", "decoder.layers.65.self_attn_layer_norm.weight", "decoder.layers.65.self_attn_layer_norm.bias", "decoder.layers.65.fc1.weight", "decoder.layers.65.fc1.bias", "decoder.layers.65.fc2.weight", "decoder.layers.65.fc2.bias", "decoder.layers.65.final_layer_norm.weight", "decoder.layers.65.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.66.flat_param_0": {"names": ["decoder.layers.66.self_attn.qkv_proj.weight", "decoder.layers.66.self_attn.qkv_proj.bias", "decoder.layers.66.self_attn.out_proj.weight", "decoder.layers.66.self_attn.out_proj.bias", "decoder.layers.66.self_attn_layer_norm.weight", "decoder.layers.66.self_attn_layer_norm.bias", "decoder.layers.66.fc1.weight", "decoder.layers.66.fc1.bias", "decoder.layers.66.fc2.weight", "decoder.layers.66.fc2.bias", "decoder.layers.66.final_layer_norm.weight", "decoder.layers.66.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.67.flat_param_0": {"names": ["decoder.layers.67.self_attn.qkv_proj.weight", "decoder.layers.67.self_attn.qkv_proj.bias", "decoder.layers.67.self_attn.out_proj.weight", "decoder.layers.67.self_attn.out_proj.bias", "decoder.layers.67.self_attn_layer_norm.weight", "decoder.layers.67.self_attn_layer_norm.bias", "decoder.layers.67.fc1.weight", "decoder.layers.67.fc1.bias", "decoder.layers.67.fc2.weight", "decoder.layers.67.fc2.bias", "decoder.layers.67.final_layer_norm.weight", "decoder.layers.67.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.68.flat_param_0": {"names": ["decoder.layers.68.self_attn.qkv_proj.weight", "decoder.layers.68.self_attn.qkv_proj.bias", "decoder.layers.68.self_attn.out_proj.weight", "decoder.layers.68.self_attn.out_proj.bias", "decoder.layers.68.self_attn_layer_norm.weight", "decoder.layers.68.self_attn_layer_norm.bias", "decoder.layers.68.fc1.weight", "decoder.layers.68.fc1.bias", "decoder.layers.68.fc2.weight", "decoder.layers.68.fc2.bias", "decoder.layers.68.final_layer_norm.weight", "decoder.layers.68.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.69.flat_param_0": {"names": ["decoder.layers.69.self_attn.qkv_proj.weight", "decoder.layers.69.self_attn.qkv_proj.bias", "decoder.layers.69.self_attn.out_proj.weight", "decoder.layers.69.self_attn.out_proj.bias", "decoder.layers.69.self_attn_layer_norm.weight", "decoder.layers.69.self_attn_layer_norm.bias", "decoder.layers.69.fc1.weight", "decoder.layers.69.fc1.bias", "decoder.layers.69.fc2.weight", "decoder.layers.69.fc2.bias", "decoder.layers.69.final_layer_norm.weight", "decoder.layers.69.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.70.flat_param_0": {"names": ["decoder.layers.70.self_attn.qkv_proj.weight", "decoder.layers.70.self_attn.qkv_proj.bias", "decoder.layers.70.self_attn.out_proj.weight", "decoder.layers.70.self_attn.out_proj.bias", "decoder.layers.70.self_attn_layer_norm.weight", "decoder.layers.70.self_attn_layer_norm.bias", "decoder.layers.70.fc1.weight", "decoder.layers.70.fc1.bias", "decoder.layers.70.fc2.weight", "decoder.layers.70.fc2.bias", "decoder.layers.70.final_layer_norm.weight", "decoder.layers.70.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.71.flat_param_0": {"names": ["decoder.layers.71.self_attn.qkv_proj.weight", "decoder.layers.71.self_attn.qkv_proj.bias", "decoder.layers.71.self_attn.out_proj.weight", "decoder.layers.71.self_attn.out_proj.bias", "decoder.layers.71.self_attn_layer_norm.weight", "decoder.layers.71.self_attn_layer_norm.bias", "decoder.layers.71.fc1.weight", "decoder.layers.71.fc1.bias", "decoder.layers.71.fc2.weight", "decoder.layers.71.fc2.bias", "decoder.layers.71.final_layer_norm.weight", "decoder.layers.71.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.72.flat_param_0": {"names": ["decoder.layers.72.self_attn.qkv_proj.weight", "decoder.layers.72.self_attn.qkv_proj.bias", "decoder.layers.72.self_attn.out_proj.weight", "decoder.layers.72.self_attn.out_proj.bias", "decoder.layers.72.self_attn_layer_norm.weight", "decoder.layers.72.self_attn_layer_norm.bias", "decoder.layers.72.fc1.weight", "decoder.layers.72.fc1.bias", "decoder.layers.72.fc2.weight", "decoder.layers.72.fc2.bias", "decoder.layers.72.final_layer_norm.weight", "decoder.layers.72.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.73.flat_param_0": {"names": ["decoder.layers.73.self_attn.qkv_proj.weight", "decoder.layers.73.self_attn.qkv_proj.bias", "decoder.layers.73.self_attn.out_proj.weight", "decoder.layers.73.self_attn.out_proj.bias", "decoder.layers.73.self_attn_layer_norm.weight", "decoder.layers.73.self_attn_layer_norm.bias", "decoder.layers.73.fc1.weight", "decoder.layers.73.fc1.bias", "decoder.layers.73.fc2.weight", "decoder.layers.73.fc2.bias", "decoder.layers.73.final_layer_norm.weight", "decoder.layers.73.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.74.flat_param_0": {"names": ["decoder.layers.74.self_attn.qkv_proj.weight", "decoder.layers.74.self_attn.qkv_proj.bias", "decoder.layers.74.self_attn.out_proj.weight", "decoder.layers.74.self_attn.out_proj.bias", "decoder.layers.74.self_attn_layer_norm.weight", "decoder.layers.74.self_attn_layer_norm.bias", "decoder.layers.74.fc1.weight", "decoder.layers.74.fc1.bias", "decoder.layers.74.fc2.weight", "decoder.layers.74.fc2.bias", "decoder.layers.74.final_layer_norm.weight", "decoder.layers.74.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.75.flat_param_0": {"names": ["decoder.layers.75.self_attn.qkv_proj.weight", "decoder.layers.75.self_attn.qkv_proj.bias", "decoder.layers.75.self_attn.out_proj.weight", "decoder.layers.75.self_attn.out_proj.bias", "decoder.layers.75.self_attn_layer_norm.weight", "decoder.layers.75.self_attn_layer_norm.bias", "decoder.layers.75.fc1.weight", "decoder.layers.75.fc1.bias", "decoder.layers.75.fc2.weight", "decoder.layers.75.fc2.bias", "decoder.layers.75.final_layer_norm.weight", "decoder.layers.75.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.76.flat_param_0": {"names": ["decoder.layers.76.self_attn.qkv_proj.weight", "decoder.layers.76.self_attn.qkv_proj.bias", "decoder.layers.76.self_attn.out_proj.weight", "decoder.layers.76.self_attn.out_proj.bias", "decoder.layers.76.self_attn_layer_norm.weight", "decoder.layers.76.self_attn_layer_norm.bias", "decoder.layers.76.fc1.weight", "decoder.layers.76.fc1.bias", "decoder.layers.76.fc2.weight", "decoder.layers.76.fc2.bias", "decoder.layers.76.final_layer_norm.weight", "decoder.layers.76.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.77.flat_param_0": {"names": ["decoder.layers.77.self_attn.qkv_proj.weight", "decoder.layers.77.self_attn.qkv_proj.bias", "decoder.layers.77.self_attn.out_proj.weight", "decoder.layers.77.self_attn.out_proj.bias", "decoder.layers.77.self_attn_layer_norm.weight", "decoder.layers.77.self_attn_layer_norm.bias", "decoder.layers.77.fc1.weight", "decoder.layers.77.fc1.bias", "decoder.layers.77.fc2.weight", "decoder.layers.77.fc2.bias", "decoder.layers.77.final_layer_norm.weight", "decoder.layers.77.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.78.flat_param_0": {"names": ["decoder.layers.78.self_attn.qkv_proj.weight", "decoder.layers.78.self_attn.qkv_proj.bias", "decoder.layers.78.self_attn.out_proj.weight", "decoder.layers.78.self_attn.out_proj.bias", "decoder.layers.78.self_attn_layer_norm.weight", "decoder.layers.78.self_attn_layer_norm.bias", "decoder.layers.78.fc1.weight", "decoder.layers.78.fc1.bias", "decoder.layers.78.fc2.weight", "decoder.layers.78.fc2.bias", "decoder.layers.78.final_layer_norm.weight", "decoder.layers.78.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.79.flat_param_0": {"names": ["decoder.layers.79.self_attn.qkv_proj.weight", "decoder.layers.79.self_attn.qkv_proj.bias", "decoder.layers.79.self_attn.out_proj.weight", "decoder.layers.79.self_attn.out_proj.bias", "decoder.layers.79.self_attn_layer_norm.weight", "decoder.layers.79.self_attn_layer_norm.bias", "decoder.layers.79.fc1.weight", "decoder.layers.79.fc1.bias", "decoder.layers.79.fc2.weight", "decoder.layers.79.fc2.bias", "decoder.layers.79.final_layer_norm.weight", "decoder.layers.79.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.80.flat_param_0": {"names": ["decoder.layers.80.self_attn.qkv_proj.weight", "decoder.layers.80.self_attn.qkv_proj.bias", "decoder.layers.80.self_attn.out_proj.weight", "decoder.layers.80.self_attn.out_proj.bias", "decoder.layers.80.self_attn_layer_norm.weight", "decoder.layers.80.self_attn_layer_norm.bias", "decoder.layers.80.fc1.weight", "decoder.layers.80.fc1.bias", "decoder.layers.80.fc2.weight", "decoder.layers.80.fc2.bias", "decoder.layers.80.final_layer_norm.weight", "decoder.layers.80.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.81.flat_param_0": {"names": ["decoder.layers.81.self_attn.qkv_proj.weight", "decoder.layers.81.self_attn.qkv_proj.bias", "decoder.layers.81.self_attn.out_proj.weight", "decoder.layers.81.self_attn.out_proj.bias", "decoder.layers.81.self_attn_layer_norm.weight", "decoder.layers.81.self_attn_layer_norm.bias", "decoder.layers.81.fc1.weight", "decoder.layers.81.fc1.bias", "decoder.layers.81.fc2.weight", "decoder.layers.81.fc2.bias", "decoder.layers.81.final_layer_norm.weight", "decoder.layers.81.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.82.flat_param_0": {"names": ["decoder.layers.82.self_attn.qkv_proj.weight", "decoder.layers.82.self_attn.qkv_proj.bias", "decoder.layers.82.self_attn.out_proj.weight", "decoder.layers.82.self_attn.out_proj.bias", "decoder.layers.82.self_attn_layer_norm.weight", "decoder.layers.82.self_attn_layer_norm.bias", "decoder.layers.82.fc1.weight", "decoder.layers.82.fc1.bias", "decoder.layers.82.fc2.weight", "decoder.layers.82.fc2.bias", "decoder.layers.82.final_layer_norm.weight", "decoder.layers.82.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.83.flat_param_0": {"names": ["decoder.layers.83.self_attn.qkv_proj.weight", "decoder.layers.83.self_attn.qkv_proj.bias", "decoder.layers.83.self_attn.out_proj.weight", "decoder.layers.83.self_attn.out_proj.bias", "decoder.layers.83.self_attn_layer_norm.weight", "decoder.layers.83.self_attn_layer_norm.bias", "decoder.layers.83.fc1.weight", "decoder.layers.83.fc1.bias", "decoder.layers.83.fc2.weight", "decoder.layers.83.fc2.bias", "decoder.layers.83.final_layer_norm.weight", "decoder.layers.83.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.84.flat_param_0": {"names": ["decoder.layers.84.self_attn.qkv_proj.weight", "decoder.layers.84.self_attn.qkv_proj.bias", "decoder.layers.84.self_attn.out_proj.weight", "decoder.layers.84.self_attn.out_proj.bias", "decoder.layers.84.self_attn_layer_norm.weight", "decoder.layers.84.self_attn_layer_norm.bias", "decoder.layers.84.fc1.weight", "decoder.layers.84.fc1.bias", "decoder.layers.84.fc2.weight", "decoder.layers.84.fc2.bias", "decoder.layers.84.final_layer_norm.weight", "decoder.layers.84.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.85.flat_param_0": {"names": ["decoder.layers.85.self_attn.qkv_proj.weight", "decoder.layers.85.self_attn.qkv_proj.bias", "decoder.layers.85.self_attn.out_proj.weight", "decoder.layers.85.self_attn.out_proj.bias", "decoder.layers.85.self_attn_layer_norm.weight", "decoder.layers.85.self_attn_layer_norm.bias", "decoder.layers.85.fc1.weight", "decoder.layers.85.fc1.bias", "decoder.layers.85.fc2.weight", "decoder.layers.85.fc2.bias", "decoder.layers.85.final_layer_norm.weight", "decoder.layers.85.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.86.flat_param_0": {"names": ["decoder.layers.86.self_attn.qkv_proj.weight", "decoder.layers.86.self_attn.qkv_proj.bias", "decoder.layers.86.self_attn.out_proj.weight", "decoder.layers.86.self_attn.out_proj.bias", "decoder.layers.86.self_attn_layer_norm.weight", "decoder.layers.86.self_attn_layer_norm.bias", "decoder.layers.86.fc1.weight", "decoder.layers.86.fc1.bias", "decoder.layers.86.fc2.weight", "decoder.layers.86.fc2.bias", "decoder.layers.86.final_layer_norm.weight", "decoder.layers.86.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.87.flat_param_0": {"names": ["decoder.layers.87.self_attn.qkv_proj.weight", "decoder.layers.87.self_attn.qkv_proj.bias", "decoder.layers.87.self_attn.out_proj.weight", "decoder.layers.87.self_attn.out_proj.bias", "decoder.layers.87.self_attn_layer_norm.weight", "decoder.layers.87.self_attn_layer_norm.bias", "decoder.layers.87.fc1.weight", "decoder.layers.87.fc1.bias", "decoder.layers.87.fc2.weight", "decoder.layers.87.fc2.bias", "decoder.layers.87.final_layer_norm.weight", "decoder.layers.87.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.88.flat_param_0": {"names": ["decoder.layers.88.self_attn.qkv_proj.weight", "decoder.layers.88.self_attn.qkv_proj.bias", "decoder.layers.88.self_attn.out_proj.weight", "decoder.layers.88.self_attn.out_proj.bias", "decoder.layers.88.self_attn_layer_norm.weight", "decoder.layers.88.self_attn_layer_norm.bias", "decoder.layers.88.fc1.weight", "decoder.layers.88.fc1.bias", "decoder.layers.88.fc2.weight", "decoder.layers.88.fc2.bias", "decoder.layers.88.final_layer_norm.weight", "decoder.layers.88.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.89.flat_param_0": {"names": ["decoder.layers.89.self_attn.qkv_proj.weight", "decoder.layers.89.self_attn.qkv_proj.bias", "decoder.layers.89.self_attn.out_proj.weight", "decoder.layers.89.self_attn.out_proj.bias", "decoder.layers.89.self_attn_layer_norm.weight", "decoder.layers.89.self_attn_layer_norm.bias", "decoder.layers.89.fc1.weight", "decoder.layers.89.fc1.bias", "decoder.layers.89.fc2.weight", "decoder.layers.89.fc2.bias", "decoder.layers.89.final_layer_norm.weight", "decoder.layers.89.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.90.flat_param_0": {"names": ["decoder.layers.90.self_attn.qkv_proj.weight", "decoder.layers.90.self_attn.qkv_proj.bias", "decoder.layers.90.self_attn.out_proj.weight", "decoder.layers.90.self_attn.out_proj.bias", "decoder.layers.90.self_attn_layer_norm.weight", "decoder.layers.90.self_attn_layer_norm.bias", "decoder.layers.90.fc1.weight", "decoder.layers.90.fc1.bias", "decoder.layers.90.fc2.weight", "decoder.layers.90.fc2.bias", "decoder.layers.90.final_layer_norm.weight", "decoder.layers.90.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.91.flat_param_0": {"names": ["decoder.layers.91.self_attn.qkv_proj.weight", "decoder.layers.91.self_attn.qkv_proj.bias", "decoder.layers.91.self_attn.out_proj.weight", "decoder.layers.91.self_attn.out_proj.bias", "decoder.layers.91.self_attn_layer_norm.weight", "decoder.layers.91.self_attn_layer_norm.bias", "decoder.layers.91.fc1.weight", "decoder.layers.91.fc1.bias", "decoder.layers.91.fc2.weight", "decoder.layers.91.fc2.bias", "decoder.layers.91.final_layer_norm.weight", "decoder.layers.91.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.92.flat_param_0": {"names": ["decoder.layers.92.self_attn.qkv_proj.weight", "decoder.layers.92.self_attn.qkv_proj.bias", "decoder.layers.92.self_attn.out_proj.weight", "decoder.layers.92.self_attn.out_proj.bias", "decoder.layers.92.self_attn_layer_norm.weight", "decoder.layers.92.self_attn_layer_norm.bias", "decoder.layers.92.fc1.weight", "decoder.layers.92.fc1.bias", "decoder.layers.92.fc2.weight", "decoder.layers.92.fc2.bias", "decoder.layers.92.final_layer_norm.weight", "decoder.layers.92.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.93.flat_param_0": {"names": ["decoder.layers.93.self_attn.qkv_proj.weight", "decoder.layers.93.self_attn.qkv_proj.bias", "decoder.layers.93.self_attn.out_proj.weight", "decoder.layers.93.self_attn.out_proj.bias", "decoder.layers.93.self_attn_layer_norm.weight", "decoder.layers.93.self_attn_layer_norm.bias", "decoder.layers.93.fc1.weight", "decoder.layers.93.fc1.bias", "decoder.layers.93.fc2.weight", "decoder.layers.93.fc2.bias", "decoder.layers.93.final_layer_norm.weight", "decoder.layers.93.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.94.flat_param_0": {"names": ["decoder.layers.94.self_attn.qkv_proj.weight", "decoder.layers.94.self_attn.qkv_proj.bias", "decoder.layers.94.self_attn.out_proj.weight", "decoder.layers.94.self_attn.out_proj.bias", "decoder.layers.94.self_attn_layer_norm.weight", "decoder.layers.94.self_attn_layer_norm.bias", "decoder.layers.94.fc1.weight", "decoder.layers.94.fc1.bias", "decoder.layers.94.fc2.weight", "decoder.layers.94.fc2.bias", "decoder.layers.94.final_layer_norm.weight", "decoder.layers.94.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.95.flat_param_0": {"names": ["decoder.layers.95.self_attn.qkv_proj.weight", "decoder.layers.95.self_attn.qkv_proj.bias", "decoder.layers.95.self_attn.out_proj.weight", "decoder.layers.95.self_attn.out_proj.bias", "decoder.layers.95.self_attn_layer_norm.weight", "decoder.layers.95.self_attn_layer_norm.bias", "decoder.layers.95.fc1.weight", "decoder.layers.95.fc1.bias", "decoder.layers.95.fc2.weight", "decoder.layers.95.fc2.bias", "decoder.layers.95.final_layer_norm.weight", "decoder.layers.95.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}} \ No newline at end of file diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/unflat.sh b/examples/tutorial/opt/inference/script/process-opt-175b/unflat.sh new file mode 100644 index 000000000..cc5c190e2 --- /dev/null +++ b/examples/tutorial/opt/inference/script/process-opt-175b/unflat.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +for i in $(seq 0 7); do + python convert_ckpt.py $1 $2 ${i} & +done + +wait $(jobs -p) diff --git a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py new file mode 100644 index 000000000..0494647d7 --- /dev/null +++ b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py @@ -0,0 +1,55 @@ +import os +import torch +from multiprocessing import Pool + +# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main +# you can use whether wget or git lfs + +path = "/path/to/your/ckpt" +new_path = "/path/to/the/processed/ckpt/" + +assert os.path.isdir(path) +files = [] +for filename in os.listdir(path): + filepath = os.path.join(path, filename) + if os.path.isfile(filepath): + files.append(filepath) + +with Pool(14) as pool: + ckpts = pool.map(torch.load, files) + +restored = {} +for ckpt in ckpts: + for k,v in ckpt.items(): + if(k[0] == 'm'): + k = k[6:] + if(k == "lm_head.weight"): + k = "head.dense.weight" + if(k == "decoder.final_layer_norm.weight"): + k = "decoder.layer_norm.weight" + if(k == "decoder.final_layer_norm.bias"): + k = "decoder.layer_norm.bias" + restored[k] = v +restored["decoder.version"] = "0.0" + + +split_num = len(restored.keys()) // 60 +count = 0 +file_count = 1 +tmp = {} +for k,v in restored.items(): + print(k) + tmp[k] = v + count = count + 1 + if(count == split_num): + filename = str(file_count) + "-restored.pt" + torch.save(tmp, os.path.join(new_path, filename)) + file_count = file_count + 1 + count = 0 + tmp = {} + +filename = str(file_count) + "-restored.pt" +torch.save(tmp, os.path.join(new_path, filename)) + + + diff --git a/examples/tutorial/opt/opt/README.md b/examples/tutorial/opt/opt/README.md new file mode 100644 index 000000000..a01209cbd --- /dev/null +++ b/examples/tutorial/opt/opt/README.md @@ -0,0 +1,76 @@ + +# Train OPT model with Colossal-AI + + +## OPT +Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. + +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. + +We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before +the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). + +## Our Modifications +We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. + +## 🚀Quick Start for Tutorial +1. Install the dependency +```bash +pip install datasets accelerate +``` +2. Run finetuning with synthetic datasets with one GPU +```bash +bash ./run_clm_synthetic.sh +``` +3. Run finetuning with 4 GPUs +```bash +bash ./run_clm_synthetic.sh 16 0 125m 4 +``` + +## Quick Start for Practical Use +You can launch training by using the following bash script + +```bash +bash ./run_clm.sh +``` + +- batch-size-per-gpu: number of samples fed to each GPU, default is 16 +- mem-cap: limit memory usage within a value in GB, default is 0 (no limit) +- model: the size of the OPT model, default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b`, you can request +the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT). +- gpu-num: the number of GPUs to use, default is 1. + +It uses `wikitext` dataset. + +To use synthetic dataset: + +```bash +bash ./run_clm_synthetic.sh +``` + +## Remarkable Performance +On a single GPU, Colossal-AI’s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed. +Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale. + +

+ +

+ +Adopting the distributed training strategy with 8 GPUs is as simple as adding a `-nprocs 8` to the training command of Colossal-AI! + +More details about behind the scenes can be found on the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d), +and a detailed tutorial will be added in [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon. diff --git a/examples/tutorial/opt/opt/benchmark.sh b/examples/tutorial/opt/opt/benchmark.sh new file mode 100644 index 000000000..f02f7629a --- /dev/null +++ b/examples/tutorial/opt/opt/benchmark.sh @@ -0,0 +1,21 @@ +export BS=16 +export MEMCAP=0 +export MODEL="6.7b" +export GPUNUM=1 + +for MODEL in "6.7b" "13b" "1.3b" +do +for GPUNUM in 8 1 +do +for BS in 16 24 32 8 +do +for MEMCAP in 0 40 +do +pkill -9 torchrun +pkill -9 python + +bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM +done +done +done +done diff --git a/examples/tutorial/opt/opt/colossalai_zero.py b/examples/tutorial/opt/opt/colossalai_zero.py new file mode 100644 index 000000000..833745f3e --- /dev/null +++ b/examples/tutorial/opt/opt/colossalai_zero.py @@ -0,0 +1,6 @@ +from colossalai.zero.shard_utils import TensorShardStrategy + +zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), + tensor_placement_policy="auto", + reuse_fp16_shard=True), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384)) diff --git a/examples/tutorial/opt/opt/context.py b/examples/tutorial/opt/opt/context.py new file mode 100644 index 000000000..95f0abf1d --- /dev/null +++ b/examples/tutorial/opt/opt/context.py @@ -0,0 +1,32 @@ +import torch.distributed as dist + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class barrier_context(): + """ + This context manager is used to allow one process to execute while blocking all + other processes in the same process group. This is often useful when downloading is required + as we only want to download in one process to prevent file corruption. + Args: + executor_rank (int): the process rank to execute without blocking, all other processes will be blocked + parallel_mode (ParallelMode): the parallel mode corresponding to a process group + Usage: + with barrier_context(): + dataset = CIFAR10(root='./data', download=True) + """ + + def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL): + # the class name is lowercase by convention + current_rank = gpc.get_local_rank(parallel_mode=parallel_mode) + self.should_block = current_rank != executor_rank + self.group = gpc.get_group(parallel_mode=parallel_mode) + + def __enter__(self): + if self.should_block: + dist.barrier(group=self.group) + + def __exit__(self, exc_type, exc_value, exc_traceback): + if not self.should_block: + dist.barrier(group=self.group) diff --git a/examples/tutorial/opt/opt/requirements.txt b/examples/tutorial/opt/opt/requirements.txt new file mode 100644 index 000000000..c34df7992 --- /dev/null +++ b/examples/tutorial/opt/opt/requirements.txt @@ -0,0 +1,6 @@ +colossalai +torch >= 1.8.1 +datasets >= 1.8.0 +sentencepiece != 0.1.92 +protobuf +accelerate == 0.13.2 diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py new file mode 100755 index 000000000..c4f576cb1 --- /dev/null +++ b/examples/tutorial/opt/opt/run_clm.py @@ -0,0 +1,636 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. + +import math +import os +import time +from itertools import chain + +import datasets +import torch +import torch.distributed as dist +from accelerate.utils import set_seed +from context import barrier_context +from datasets import load_dataset +from packaging import version +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +import colossalai +import transformers +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device, get_dataloader +from colossalai.utils.model.colo_init_context import ColoInitContext +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoTokenizer, + GPT2Tokenizer, + OPTForCausalLM, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils.versions import require_version + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def get_time_stamp(): + torch.cuda.synchronize() + return time.time() + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument("-s", "--synthetic", action="store_true") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument("--train_file", + type=str, + default=None, + help="A csv or a json file containing the training data.") + parser.add_argument("--validation_file", + type=str, + default=None, + help="A csv or a json file containing the validation data.") + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument("--num_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)."), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument("--overwrite_cache", + type=bool, + default=False, + help="Overwrite the cached training and evaluation sets") + parser.add_argument("--no_keep_linebreaks", + action="store_true", + help="Do not keep line breaks when using TXT files.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_model_id", + type=str, + help="The name of the repository to keep in sync with the local `output_dir`.") + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed."), + ) + + parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") + parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + args = parser.parse_args() + + # Sanity checks + if not args.synthetic: + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print("Using {} GB of GPU memory".format(size_in_GB)) + + +class DummyDataloader: + + def __init__(self, length, batch_size, seq_len, vocab_size): + self.length = length + self.batch_size = batch_size + self.seq_len = seq_len + self.vocab_size = vocab_size + + def generate(self): + input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device()) + attention_mask = torch.ones_like(input_ids) + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +def main(): + args = parse_args() + disable_existing_loggers() + colossalai.launch_from_torch(config=dict()) + logger = get_dist_logger() + is_main_process = dist.get_rank() == 0 + + if is_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") + + # Handle the repository creation + with barrier_context(): + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + logger.info("Start preparing dataset", ranks=[0]) + if not args.synthetic: + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + logger.info("Dataset is prepared", ranks=[0]) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + logger.info("Model config has been created", ranks=[0]) + + if args.model_name_or_path == 'facebook/opt-13b': + tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) + else: + print(f'load model from {args.model_name_or_path}') + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0]) + + if args.init_in_cpu: + init_dev = torch.device('cpu') + else: + init_dev = get_current_device() + + # build model + if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': + # currently, there has a bug in pretrained opt-13b + # we can not import it until huggingface fix it + logger.info("Train a new model from scratch", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM(config) + else: + logger.info("Finetune a pre-trained model", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False) + + # enable graident checkpointing + model.gradient_checkpointing_enable() + + PLACEMENT_POLICY = 'auto' + cai_version = colossalai.__version__ + logger.info(f'using Colossal-AI version {cai_version}') + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + pg = ProcessGroup() + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) + gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) + model = ZeroDDP(model, gemini_manager) + + logger.info(f'{model.__class__.__name__} has been created', ranks=[0]) + + if not args.synthetic: + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx.") + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.") + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i:i + block_size] for i in range(0, total_length, block_size) + ] for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + if not args.synthetic: + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + train_dataset = lm_datasets["train"] + eval_dataset = lm_datasets["validation"] + + # Log a few random samples from the training set: + # for index in random.sample(range(len(train_dataset)), 3): + # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = get_dataloader(train_dataset, + shuffle=True, + add_sampler=True, + collate_fn=default_data_collator, + batch_size=args.per_device_train_batch_size) + eval_dataloader = DataLoader(eval_dataset, + collate_fn=default_data_collator, + batch_size=args.per_device_eval_batch_size) + else: + train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings, + config.vocab_size) + eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings, + config.vocab_size) + logger.info("Dataloaders have been created", ranks=[0]) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate) + optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA) + num_train_samples = len(train_dataset) if not args.synthetic else 30 * total_batch_size + num_eval_samples = len(eval_dataset) if not args.synthetic else 10 * total_batch_size + + logger.info("***** Running training *****", ranks=[0]) + logger.info(f" Num examples = {num_train_samples}", ranks=[0]) + logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0]) + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0]) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0]) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0]) + logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) + completed_steps = 0 + starting_epoch = 0 + global_step = 0 + + for epoch in range(starting_epoch, args.num_train_epochs): + + if completed_steps >= args.max_train_steps: + break + + model.train() + for step, batch in enumerate(train_dataloader): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(use_cache=False, **batch) + loss = outputs['loss'] + optimizer.backward(loss) + + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + global_step += 1 + logger.info("Global step {} finished".format(global_step + 1), ranks=[0]) + + if completed_steps >= args.max_train_steps: + break + + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(**batch) + + loss = outputs['loss'].unsqueeze(0) + losses.append(loss) + + losses = torch.cat(losses) + losses = losses[:num_eval_samples] + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0]) + + if args.output_dir is not None: + model_state = model.state_dict() + if is_main_process: + torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + dist.barrier() + # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + # model.load_state_dict(load_state, strict=False) + + logger.info("Training finished", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/opt/opt/run_clm.sh b/examples/tutorial/opt/opt/run_clm.sh new file mode 100644 index 000000000..858d3325a --- /dev/null +++ b/examples/tutorial/opt/opt/run_clm.sh @@ -0,0 +1,22 @@ +set -x +export BS=${1:-16} +export MEMCAP=${2:-0} +export MODEL=${3:-"125m"} +export GPUNUM=${4:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/tutorial/opt/opt/run_clm_synthetic.sh b/examples/tutorial/opt/opt/run_clm_synthetic.sh new file mode 100644 index 000000000..80435f16c --- /dev/null +++ b/examples/tutorial/opt/opt/run_clm_synthetic.sh @@ -0,0 +1,21 @@ +set -x +export BS=${1:-16} +export MEMCAP=${2:-0} +export MODEL=${3:-"125m"} +export GPUNUM=${4:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + -s \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/tutorial/requirements.txt b/examples/tutorial/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/tutorial/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/tutorial/sequence_parallel/README.md b/examples/tutorial/sequence_parallel/README.md new file mode 100644 index 000000000..7058f53db --- /dev/null +++ b/examples/tutorial/sequence_parallel/README.md @@ -0,0 +1,151 @@ +# Sequence Parallelism with BERT + +In this example, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate +activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length. + +Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120) + +## 🚀Quick Start +1. Run with the following command +```bash +export PYTHONPATH=$PWD +colossalai run --nproc_per_node 4 train.py -s +``` +2. The default config is sequence parallel size = 2, pipeline size = 1, let’s change pipeline size to be 2 and try it again. + + +## How to Prepare WikiPedia Dataset + +First, let's prepare the WikiPedia dataset from scratch. To generate a preprocessed dataset, we need four items: +1. raw WikiPedia dataset +2. wikipedia extractor (extract data from the raw dataset) +3. vocabulary file +4. preprocessing scripts (generate final data from extracted data) + +For the preprocessing script, we thank Megatron-LM for providing a preprocessing script to generate the corpus file. + +```python +# download raw data +mkdir data && cd ./data +wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2 + +# install wiki extractor +git clone https://github.com/FrankLeeeee/wikiextractor.git +pip install ./wikiextractor + +# extractmodule +wikiextractor --json enwiki-latest-pages-articles.xml.bz2 +cat text/*/* > ./corpus.json +cd .. + +# download vocab file +mkdir vocab && cd ./vocab +wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt +cd .. + +# preprocess some data +git clone https://github.com/NVIDIA/Megatron-LM.git +cd ./Megatron-LM +python tools/preprocess_data.py \ + --input ../data/corpus.json \ + --output-prefix my-bert \ + --vocab ../vocab/bert-large-uncased-vocab.txt \ + --dataset-impl mmap \ + --tokenizer-type BertWordPieceLowerCase \ + --split-sentences \ + --workers 24 +``` + +After running the preprocessing scripts, you will obtain two files: +1. my-bert_text_sentence.bin +2. my-bert_text_sentence.idx + +If you happen to encouter `index out of range` problem when running Megatron's script, +this is probably because that a sentence starts with a punctuation and cannot be tokenized. A work-around is to update `Encoder.encode` method with the code below: + +```python +class Encoder(object): + def __init__(self, args): + ... + + def initializer(self): + ... + + def encode(self, json_line): + data = json.loads(json_line) + ids = {} + for key in self.args.json_keys: + text = data[key] + doc_ids = [] + + # lsg: avoid sentences which start with a punctuation + # as it cannot be tokenized by splitter + if len(text) > 0 and text[0] in string.punctuation: + text = text[1:] + + for sentence in Encoder.splitter.tokenize(text): + sentence_ids = Encoder.tokenizer.tokenize(sentence) + if len(sentence_ids) > 0: + doc_ids.append(sentence_ids) + if len(doc_ids) > 0 and self.args.append_eod: + doc_ids[-1].append(Encoder.tokenizer.eod) + ids[key] = doc_ids + return ids, len(json_line) +``` + +## How to Train with Sequence Parallelism + +We provided `train.py` for you to execute training. Before invoking the script, there are several +steps to perform. + +### Step 1. Set data path and vocab path + +At the top of `config.py`, you can see two global variables `DATA_PATH` and `VOCAB_FILE_PATH`. + +```python +DATA_PATH = +VOCAB_FILE_PATH = +``` + +`DATA_PATH` refers to the path to the data file generated by Megatron's script. For example, in the section above, you should get two data files (my-bert_text_sentence.bin and my-bert_text_sentence.idx). You just need to `DATA_PATH` to the path to the bin file without the file extension. + +For example, if your my-bert_text_sentence.bin is /home/Megatron-LM/my-bert_text_sentence.bin, then you should set + +```python +DATA_PATH = '/home/Megatron-LM/my-bert_text_sentence' +``` + +The `VOCAB_FILE_PATH` refers to the path to the vocabulary downloaded when you prepare the dataset +(e.g. bert-large-uncased-vocab.txt). + +### Step 3. Make Dataset Helper + +Build BERT dataset helper. Requirements are `CUDA`, `g++`, `pybind11` and `make`. + +```python +cd ./data/datasets +make +``` + +### Step 3. Configure your parameters + +In the `config.py` provided, a set of parameters are defined including training scheme, model, etc. +You can also modify the ColossalAI setting. For example, if you wish to parallelize over the +sequence dimension on 8 GPUs. You can change `size=4` to `size=8`. If you wish to use pipeline parallelism, you can set `pipeline=`. + +### Step 4. Invoke parallel training + +Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your +machine setting. + +- If you are using a single machine with multiple GPUs, PyTorch launch utility can easily let you + start your script. A sample command is like below: + + ```bash + colossalai run --nproc_per_node --master_addr localhost --master_port 29500 train.py + ``` + +- If you are using multiple machines with multiple GPUs, we suggest that you refer to `colossalai + launch_from_slurm` or `colossalai.launch_from_openmpi` as it is easier to use SLURM and OpenMPI + to start multiple processes over multiple nodes. If you have your own launcher, you can fall back + to the default `colossalai.launch` function. diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py new file mode 100644 index 000000000..df0c5282f --- /dev/null +++ b/examples/tutorial/sequence_parallel/config.py @@ -0,0 +1,38 @@ +from colossalai.amp import AMP_TYPE + +DATA_PATH = '' +VOCAB_FILE_PATH = '' + +# hyper-parameters +TRAIN_ITERS = 1000000 +DECAY_ITERS = 990000 +WARMUP_FRACTION = 0.01 +GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU +EVAL_ITERS = 10 +EVAL_INTERVAL = 10 +LR = 0.0001 +MIN_LR = 1e-05 +WEIGHT_DECAY = 0.01 +SEQ_LENGTH = 512 + +# BERT config +DEPTH = 12 +NUM_ATTENTION_HEADS = 12 +HIDDEN_SIZE = 768 + +# model config +ADD_BINARY_HEAD = False + +# random seed +SEED = 1234 + +# pipeline config +# only enabled when pipeline > 1 +NUM_MICRO_BATCHES = 4 + +# colossalai config +parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence')) + +fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True) + +gradient_handler = [dict(type='SequenceParallelGradientHandler')] diff --git a/examples/tutorial/sequence_parallel/data/__init__.py b/examples/tutorial/sequence_parallel/data/__init__.py new file mode 100644 index 000000000..1ef2d9993 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/__init__.py @@ -0,0 +1,102 @@ +from colossalai.context.parallel_context import ParallelContext +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.context import ParallelMode +from .datasets.data_samplers import build_pretraining_data_loader +from .datasets.builder import build_train_valid_test_datasets +import torch + + +def cyclic_iter(iter): + while True: + for x in iter: + yield x + + +def build_train_valid_test_data_iterators(train_iters, + global_batch_size, + eval_interval, + eval_iters, + dataloader_type='single', + **kwargs + ): + (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) + + logger = get_dist_logger() + logger.info('> building train, validation, and test datasets ...', ranks=[0]) + + # Backward compatibility, assume fixed batch size. + # if iteration > 0 and consumed_train_samples == 0: + # assert train_samples is None, \ + # 'only backward compatibility support for iteration-based training' + # consumed_train_samples = iteration * global_batch_size + # if iteration > 0 and consumed_valid_samples == 0: + # if train_samples is None: + # consumed_valid_samples = (iteration // eval_interval) * \ + # eval_iters * global_batch_size + + # Data loader only on rank 0 of each model parallel group. + if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: + + # Number of train/valid/test samples. + train_samples = train_iters * global_batch_size + eval_iters_ = (train_iters // eval_interval + 1) * eval_iters + test_iters = eval_iters + train_val_test_num_samples = [train_samples, + eval_iters_ * global_batch_size, + test_iters * global_batch_size] + logger.info(' > datasets target sizes (minimum size):') + logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0]) + logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) + logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0]) + + # Build the datasets. + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + train_valid_test_num_samples=train_val_test_num_samples, **kwargs) + + # Build dataloaders. + dp_size = gpc.get_world_size(ParallelMode.DATA) + train_dataloader = build_pretraining_data_loader( + train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) + valid_dataloader = build_pretraining_data_loader( + valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) + test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size) + + # Flags to know if we need to do training/validation/testing. + do_train = train_dataloader is not None and train_iters > 0 + do_valid = valid_dataloader is not None and eval_iters > 0 + do_test = test_dataloader is not None and eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. + flags = torch.cuda.LongTensor( + [int(do_train), int(do_valid), int(do_test)]) + else: + flags = torch.cuda.LongTensor([0, 0, 0]) + + # Broadcast num tokens. + torch.distributed.broadcast(flags, + gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + group=gpc.get_group(ParallelMode.TENSOR)) + + # Build iterators. + dl_type = dataloader_type + assert dl_type in ['single', 'cyclic'] + + if train_dataloader is not None: + train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader)) + else: + train_data_iterator = None + + if valid_dataloader is not None: + valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(valid_dataloader)) + else: + valid_data_iterator = None + + if test_dataloader is not None: + test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(test_dataloader)) + else: + test_data_iterator = None + + return train_data_iterator, valid_data_iterator, test_data_iterator diff --git a/examples/tutorial/sequence_parallel/data/bert_helper.py b/examples/tutorial/sequence_parallel/data/bert_helper.py new file mode 100644 index 000000000..d092db3e7 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/bert_helper.py @@ -0,0 +1,165 @@ +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +import torch + +_MAX_DATA_DIM = 5 + + +def _build_key_size_numel_dictionaries(keys, data): + """Build the size on rank 0 and broadcast.""" + max_dim = _MAX_DATA_DIM + sizes = [0 for _ in range(max_dim) for _ in keys] + + # Pack the sizes on rank zero. + if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: + offset = 0 + for key in keys: + assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + size = data[key].size() + for i, s in enumerate(size): + sizes[i + offset] = s + offset += max_dim + + # Move to GPU and broadcast. + sizes_cuda = torch.cuda.LongTensor(sizes) + torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + group=gpc.get_group(ParallelMode.TENSOR)) + + # Move back to cpu and unpack. + sizes_cpu = sizes_cuda.cpu() + key_size = {} + key_numel = {} + total_numel = 0 + offset = 0 + for key in keys: + i = 0 + size = [] + numel = 1 + while sizes_cpu[offset + i] > 0: + this_size = sizes_cpu[offset + i] + size.append(this_size) + numel *= this_size + i += 1 + key_size[key] = size + key_numel[key] = numel + total_numel += numel + offset += max_dim + + return key_size, key_numel, total_numel + + +def broadcast_data(keys, data, datatype): + """Broadcast data from rank zero of each model parallel group to the + members of the same model parallel group. + + Arguments: + keys: list of keys in the data dictionary to be broadcasted + data: data dictionary of string keys and cpu tensor values. + datatype: torch data type of all tensors in data associated + with keys. + """ + # Build (key, size) and (key, number of elements) dictionaries along + # with the total number of elements on all ranks. + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, + data) + + # Pack on rank zero. + if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # Check that all keys have the same data type. + # Flatten the data associated with the keys + flatten_data = torch.cat( + [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + else: + flatten_data = torch.empty(total_numel, + device=torch.cuda.current_device(), + dtype=datatype) + + # Broadcast + torch.distributed.broadcast(flatten_data, + gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + group=gpc.get_group(ParallelMode.TENSOR)) + + # Unpack + output = {} + offset = 0 + for key in keys: + size = key_size[key] + numel = key_numel[key] + output[key] = flatten_data.narrow(0, offset, numel).view(size) + offset += numel + + return output + + +def get_batch(data_iterator): + """Build the batch.""" + + # Items and their type. + keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = broadcast_data(keys, data, datatype) + + # Unpack. + tokens = data_b['text'].long() + types = data_b['types'].long() + sentence_order = data_b['is_random'].long() + loss_mask = data_b['loss_mask'].float() + lm_labels = data_b['labels'].long() + padding_mask = data_b['padding_mask'].long() + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + +def get_batch_for_sequence_parallel(data_iterator): + """Build the batch.""" + + # Items and their type. + keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + + # unpack + data_b = broadcast_data(keys, data, datatype) + + # # get tensor parallel local rank + global_rank = torch.distributed.get_rank() + local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) + local_rank = global_rank % local_world_size + seq_length = data_b['text'].size(1) + sub_seq_length = seq_length // local_world_size + sub_seq_start = local_rank * sub_seq_length + sub_seq_end = (local_rank+1) * sub_seq_length + # + # # Unpack. + tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() + types = data_b['types'][:, sub_seq_start:sub_seq_end].long() + sentence_order = data_b['is_random'].long() + loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float() + lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long() + padding_mask = data_b['padding_mask'].long() + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + +class SequenceParallelDataIterator: + + def __init__(self, data_iter): + self.data_iter = data_iter + + + def __iter__(self): + return self.data_iter + + def __next__(self): + return get_batch_for_sequence_parallel(self.data_iter) \ No newline at end of file diff --git a/examples/tutorial/sequence_parallel/data/datasets/Makefile b/examples/tutorial/sequence_parallel/data/datasets/Makefile new file mode 100644 index 000000000..8f9db7686 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/Makefile @@ -0,0 +1,9 @@ +CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = helpers +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/examples/tutorial/sequence_parallel/data/datasets/__init__.py b/examples/tutorial/sequence_parallel/data/datasets/__init__.py new file mode 100644 index 000000000..cd5f898c6 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/__init__.py @@ -0,0 +1 @@ +from . import indexed_dataset diff --git a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py new file mode 100644 index 000000000..d6388bd9f --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BERT Style dataset.""" + +import os +import time + +import numpy as np +import torch +from torch.utils.data import Dataset + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger + +from ..tokenizer import get_tokenizer +from .dataset_utils import ( + create_masked_lm_predictions, + create_tokens_and_tokentypes, + get_a_and_b_segments, + pad_and_convert_to_numpy, + truncate_segments, +) + +try: + from . import helpers +except: + print("helper is not built, ignore this message if you are using synthetic data.") + + +class BertDataset(Dataset): + + def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length, + short_seq_prob, seed, binary_head): + + # Params to store. + self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + self.binary_head = binary_head + + # Dataset. + self.indexed_dataset = indexed_dataset + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping_( + self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + self.max_seq_length - 3, # account for added tokens, + short_seq_prob, + self.seed, + self.name, + self.binary_head) + + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + start_idx, end_idx, seq_length = self.samples_mapping[idx] + sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1 + np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32)) + return build_training_sample( + sample, + seq_length, + self.max_seq_length, # needed for padding + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, + self.sep_id, + self.mask_id, + self.pad_id, + self.masked_lm_prob, + np_rng, + self.binary_head) + + +def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, + seed, name, binary_head): + logger = get_dist_logger() + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " + "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) + indexmap_filename += '_{}s'.format(seed) + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. + assert indexed_dataset.doc_idx.dtype == np.int64 + assert indexed_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0]) + # First compile and then import. + samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs, + max_num_samples, max_seq_length, short_seq_prob, seed, verbose, + 2 if binary_head else 1) + logger.info('\n > done building samples index maping', ranks=[0]) + np.save(indexmap_filename, samples_mapping, allow_pickle=True) + logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0]) + # Make sure all the ranks have built the mapping + logger.info('\n > elapsed time to build and save samples mapping ' + '(seconds): {:4f}'.format(time.time() - start_time), + ranks=[0]) + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA)) + if gpc.is_initialized(ParallelMode.PIPELINE): + torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE)) + assert counts[0].item() == (torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))) + + # Load indexed dataset. + start_time = time.time() + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) + + '\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) + + '\n total number of samples: {}'.format(samples_mapping.shape[0]), + ranks=[0]) + + return samples_mapping + + +def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id, + sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head): + """Build training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_id: Start of example id. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + masked_lm_prob: Probability to mask tokens. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + """ + + if binary_head: + # We assume that we have at least two sentences in the sample + assert len(sample) > 1 + assert target_seq_length <= max_seq_length + + # Divide sample into two segments (A and B). + if binary_head: + tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng) + else: + tokens_a = [] + for j in range(len(sample)): + tokens_a.extend(sample[j]) + tokens_b = [] + is_next_random = False + + # Truncate to `target_sequence_length`. + max_num_tokens = target_seq_length + truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng) + + # Build tokens and toketypes. + tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id) + + # Masking. + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, masked_positions, masked_labels, + _) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id, + mask_id, max_predictions_per_seq, np_rng) + + # Padding. + tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ + = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length) + + train_sample = { + 'text': tokens_np, + 'types': tokentypes_np, + 'labels': labels_np, + 'is_random': int(is_next_random), + 'loss_mask': loss_mask_np, + 'padding_mask': padding_mask_np, + 'truncated': int(truncated) + } + return train_sample diff --git a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py new file mode 100644 index 000000000..6a06c869d --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blendable dataset.""" + +import time + +import numpy as np +import torch + + +class BlendableDataset(torch.utils.data.Dataset): + + def __init__(self, datasets, weights): + + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = 0 + for dataset in self.datasets: + self.size += len(dataset) + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # Build indices. + start_time = time.time() + assert num_datasets < 255 + self.dataset_index = np.zeros(self.size, dtype=np.uint8) + self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from . import helpers + helpers.build_blending_indices(self.dataset_index, + self.dataset_sample_index, + weights, num_datasets, self.size, + torch.distributed.get_rank() == 0) + print('> elapsed time for building blendable dataset indices: ' + '{:.2f} (sec)'.format(time.time() - start_time)) + + def __len__(self): + return self.size + + def __getitem__(self, idx): + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return self.datasets[dataset_idx][sample_idx] diff --git a/examples/tutorial/sequence_parallel/data/datasets/builder.py b/examples/tutorial/sequence_parallel/data/datasets/builder.py new file mode 100644 index 000000000..6106f833b --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/builder.py @@ -0,0 +1,152 @@ +from .blendable_dataset import BlendableDataset +from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ +from .bert_dataset import BertDataset +from colossalai.logging import get_dist_logger + +DSET_TYPE_BERT = 'standard_bert' +DSET_TYPE_ICT = 'ict' +DSET_TYPE_T5 = 't5' + +DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + # Get start and end indices of train/valid/train into doc-idx + # Note that doc-idx is designed to be num-docs + 1 so we can + # easily iterate over it. + total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + logger = get_dist_logger() + + # Print stats about the splits. + logger.info('\n > dataset split:', ranks=[0]) + + def print_split_stats(name, index): + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + logger.info('\n {}:'.format(name) + + '\n document indices in [{}, {}) total of {} documents'.format( + splits[index], splits[index + 1], + splits[index + 1] - splits[index]) + + '\n sentence indices in [{}, {}) total of {} sentences'.format( + start_index, end_index, + end_index - start_index), + ranks=[0]) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + # Build the dataset accordingly. + kwargs = dict( + name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=train_valid_test_num_samples[index], + max_seq_length=max_seq_length, + seed=seed, + ) + + if dataset_type != DSET_TYPE_BERT: + raise NotImplementedError("Only BERT dataset is supported") + else: + dataset = BertDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + binary_head=binary_head, + **kwargs + ) + + # Set the original pointer so dataset remains the main dataset. + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == \ + (total_num_of_documents + 1) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, + skip_warmup, + binary_head, + dataset_type=dataset_type) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, masked_lm_prob, short_seq_prob, + seed, skip_warmup, binary_head, dataset_type=dataset_type) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) diff --git a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py new file mode 100644 index 000000000..cf547ad97 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py @@ -0,0 +1,153 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dataloaders.""" + +import torch +import random +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode + + +def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): + """Build dataloader given an input dataset.""" + + if dataset is None: + return None + + # Megatron sampler + if dataloader_type == 'single': + batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) + elif dataloader_type == 'cyclic': + batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) + else: + raise Exception('{} dataloader type is not supported.'.format(dataloader_type)) + + # Torch dataloader. + return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True) + + +class MegatronPretrainingSampler: + + def __init__(self, + total_samples, + consumed_samples, + micro_batch_size, + data_parallel_rank, + data_parallel_size, + drop_last=True): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.drop_last = drop_last + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.consumed_samples < self.total_samples, \ + 'no samples left to consume: {}, {}'.format(self.consumed_samples, + self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +class MegatronPretrainingRandomSampler: + + def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_size = data_parallel_size + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.last_batch_size = \ + self.total_samples % self.micro_batch_times_data_parallel_size + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + # data sharding and random sampling + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ + * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + + batch = [] + # Last batch if not complete will be dropped. + for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] diff --git a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py new file mode 100644 index 000000000..cf4e4763f --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py @@ -0,0 +1,592 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, and NVIDIA. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Most of the code here has been copied from: +# https://github.com/google-research/albert/blob/master/create_pretraining_data.py +# with some modifications. + +import math +import time +import collections +from colossalai.logging import get_dist_logger +import numpy as np +from .blendable_dataset import BlendableDataset +from .indexed_dataset import make_dataset as make_indexed_dataset + +DSET_TYPE_STD = 'standard_bert' +DSET_TYPE_ICT = 'ict' + +DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] + + +def get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples): + + # The data prefix should be in the format of: + # weight-1, data-prefix-1, weight-2, data-prefix-2, .. + assert len(data_prefix) % 2 == 0 + num_datasets = len(data_prefix) // 2 + weights = [0]*num_datasets + prefixes = [0]*num_datasets + for i in range(num_datasets): + weights[i] = float(data_prefix[2*i]) + prefixes[i] = (data_prefix[2*i+1]).strip() + # Normalize weights + weight_sum = 0.0 + for weight in weights: + weight_sum += weight + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + + # Add 0.5% (the 1.005 factor) so in case the bleding dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + datasets_train_valid_test_num_samples = [] + for weight in weights: + datasets_train_valid_test_num_samples.append( + [int(math.ceil(val * weight * 1.005)) + for val in train_valid_test_num_samples]) + + return prefixes, weights, datasets_train_valid_test_num_samples + + +def compile_helper(): + """Compile helper function ar runtime. Make sure this + is invoked on a single process.""" + import os + import subprocess + path = os.path.abspath(os.path.dirname(__file__)) + ret = subprocess.run(['make', '-C', path]) + if ret.returncode != 0: + print("Making C++ dataset helpers module failed, exiting.") + import sys + sys.exit(1) + + +def get_a_and_b_segments(sample, np_rng): + """Divide sample into a and b segments.""" + + # Number of sentences in the sample. + n_sentences = len(sample) + # Make sure we always have two sentences. + assert n_sentences > 1, 'make sure each sample has at least two sentences.' + + # First part: + # `a_end` is how many sentences go into the `A`. + a_end = 1 + if n_sentences >= 3: + # Note that randin in numpy is exclusive. + a_end = np_rng.randint(1, n_sentences) + tokens_a = [] + for j in range(a_end): + tokens_a.extend(sample[j]) + + # Second part: + tokens_b = [] + for j in range(a_end, n_sentences): + tokens_b.extend(sample[j]) + + # Random next: + is_next_random = False + if np_rng.random() < 0.5: + is_next_random = True + tokens_a, tokens_b = tokens_b, tokens_a + + return tokens_a, tokens_b, is_next_random + + +def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): + """Truncates a pair of sequences to a maximum sequence length.""" + #print(len_a, len_b, max_num_tokens) + assert len_a > 0 + if len_a + len_b <= max_num_tokens: + return False + while len_a + len_b > max_num_tokens: + if len_a > len_b: + len_a -= 1 + tokens = tokens_a + else: + len_b -= 1 + tokens = tokens_b + if np_rng.random() < 0.5: + del tokens[0] + else: + tokens.pop() + return True + + +def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): + """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" + + tokens = [] + tokentypes = [] + # [CLS]. + tokens.append(cls_id) + tokentypes.append(0) + # Segment A. + for token in tokens_a: + tokens.append(token) + tokentypes.append(0) + # [SEP]. + tokens.append(sep_id) + tokentypes.append(0) + # Segment B. + for token in tokens_b: + tokens.append(token) + tokentypes.append(1) + if tokens_b: + # [SEP]. + tokens.append(sep_id) + tokentypes.append(1) + + return tokens, tokentypes + + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", + ["index", "label"]) + + +def is_start_piece(piece): + """Check if the current word piece is the starting piece (BERT).""" + # When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + return not piece.startswith("##") + + +def create_masked_lm_predictions(tokens, + vocab_id_list, vocab_id_to_token_dict, + masked_lm_prob, + cls_id, sep_id, mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False): + """Creates the predictions for the masked LM objective. + Note: Tokens here are vocab ids and not text tokens.""" + + cand_indexes = [] + # Note(mingdachen): We create a list for recording if the piece is + # the starting piece of current token, where 1 means true, so that + # on-the-fly whole word masking is possible. + token_boundary = [0] * len(tokens) + + for (i, token) in enumerate(tokens): + if token == cls_id or token == sep_id: + token_boundary[i] = 1 + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (do_whole_word_mask and len(cand_indexes) >= 1 and + not is_start_piece(vocab_id_to_token_dict[token])): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + if is_start_piece(vocab_id_to_token_dict[token]): + token_boundary[i] = 1 + + output_tokens = list(tokens) + + masked_lm_positions = [] + masked_lm_labels = [] + + if masked_lm_prob == 0: + return (output_tokens, masked_lm_positions, + masked_lm_labels, token_boundary) + + num_to_predict = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + + # Note(mingdachen): + # By default, we set the probabilities to favor shorter ngram sequences. + ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) + pvals = 1. / np.arange(1, max_ngrams + 1) + pvals /= pvals.sum(keepdims=True) + + if favor_longer_ngram: + pvals = pvals[::-1] + + ngram_indexes = [] + for idx in range(len(cand_indexes)): + ngram_index = [] + for n in ngrams: + ngram_index.append(cand_indexes[idx:idx + n]) + ngram_indexes.append(ngram_index) + + np_rng.shuffle(ngram_indexes) + + masked_lms = [] + covered_indexes = set() + for cand_index_set in ngram_indexes: + if len(masked_lms) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes: + continue + + n = np_rng.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # Note(mingdachen): + # Repeatedly looking for a candidate that does not exceed the + # maximum number of predictions by trying shorter ngrams. + while len(masked_lms) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if np_rng.random() < 0.8: + masked_token = mask_id + else: + # 10% of the time, keep original + if np_rng.random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))] + + output_tokens[index] = masked_token + + masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) + assert len(masked_lms) <= num_to_predict + + np_rng.shuffle(ngram_indexes) + + select_indexes = set() + if do_permutation: + for cand_index_set in ngram_indexes: + if len(select_indexes) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes or index in select_indexes: + continue + + n = np.random.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + + while len(select_indexes) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(select_indexes) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes or index in select_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + select_indexes.add(index) + assert len(select_indexes) <= num_to_predict + + select_indexes = sorted(select_indexes) + permute_indexes = list(select_indexes) + np_rng.shuffle(permute_indexes) + orig_token = list(output_tokens) + + for src_i, tgt_i in zip(select_indexes, permute_indexes): + output_tokens[src_i] = orig_token[tgt_i] + masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + + for p in masked_lms: + masked_lm_positions.append(p.index) + masked_lm_labels.append(p.label) + + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) + + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, + dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, + skip_warmup, + binary_head, + dataset_type=dataset_type) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, masked_lm_prob, short_seq_prob, + seed, skip_warmup, binary_head, dataset_type=dataset_type) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, skip_warmup, + binary_head, + dataset_type='standard_bert'): + logger = get_dist_logger() + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + title_dataset = get_indexed_dataset_(args.titles_data_path, + data_impl, + skip_warmup) + + # Get start and end indices of train/valid/train into doc-idx + # Note that doc-idx is designed to be num-docs + 1 so we can + # easily iterate over it. + total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + logger.info('\n > dataset split:') + + def print_split_stats(name, index): + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + logger.info('\n {}:'.format(name) + + '\n document indices in [{}, {}) total of {} documents'.format( + splits[index], + splits[index + 1], + splits[index + 1] - splits[index]) + + '\n sentence indices in [{}, {}) total of {} sentences'.format( + start_index, + end_index, + end_index - start_index), + ranks=[0]) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + from .bert_dataset import BertDataset + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + # Build the dataset accordingly. + kwargs = dict( + name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=train_valid_test_num_samples[index], + max_seq_length=max_seq_length, + seed=seed, + binary_head=binary_head + ) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + dataset = ICTDataset( + block_dataset=indexed_dataset, + title_dataset=title_dataset, + query_in_block_prob=args.query_in_block_prob, + use_one_sent_docs=args.use_one_sent_docs, + **kwargs + ) + else: + dataset = BertDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + **kwargs + ) + + # Set the original pointer so dataset remains the main dataset. + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == \ + (total_num_of_documents + 1) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + logger = get_dist_logger() + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, + data_impl, + skip_warmup) + assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] + logger.info('\n > building dataset index ...', ranks=[0]) + logger.info('\n > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time), ranks=[0]) + logger.info('\n > indexed dataset stats:' + + '\n number of documents: {}'.format( + indexed_dataset.doc_idx.shape[0] - 1) + + '\n number of sentences: {}'.format( + indexed_dataset.sizes.shape[0]), + ranks=[0] + ) + + return indexed_dataset + + +def get_train_valid_test_split_(splits_string, size): + """ Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(',') != -1: + splits = [float(s) for s in splits_string.split(',')] + elif splits_string.find('/') != -1: + splits = [float(s) for s in splits_string.split('/')] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp new file mode 100644 index 000000000..e45926a97 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp @@ -0,0 +1,717 @@ +/* + coding=utf-8 + Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + + +/* Helper methods for fast index mapping builds */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +const int32_t LONG_SENTENCE_LEN = 512; + + +void build_blending_indices(py::array_t& dataset_index, + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, const bool verbose) { + /* Given multiple datasets and a weighting array, build samples + such that it follows those wieghts.*/ + + if (verbose) { + std::cout << "> building indices for blendable datasets ..." << std::endl; + } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. + int64_t current_samples[num_datasets]; + for(int64_t i = 0; i < num_datasets; ++i) { + current_samples[i] = 0; + } + + // For each sample: + for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + + // Determine where the max error in sampling is happening. + auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = weights_ptr[0] * sample_idx_double - + static_cast(current_samples[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(current_samples[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } + + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; + + // Update the total samples. + current_samples[max_error_index] += 1; + + } + + // print info + if (verbose) { + std::cout << " > sample ratios:" << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = static_cast(current_samples[dataset_idx]) / + static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << + weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + } + } + +} + + +py::array build_sample_idx(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void *mem_) { + int32_t *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples+1, 2}, // shape + {2*byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references + +} + + +inline int32_t get_target_sample_len(const int32_t short_seq_ratio, + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { + return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; +} + + +template +py::array build_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const double short_seq_prob, + const int32_t seed, + const bool verbose, + const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << + endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent > 1) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Check for overflow. + if ((3 * map_index + 2) > + std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() + << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + + +py::array build_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const double short_seq_prob, + const int seed, + const bool verbose, + const int32_t min_num_sent) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } +} + +template +py::array build_blocks_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const int32_t seed, + const bool verbose, + const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending sentence index, + // the index of the document from which the block comes (used for fetching titles) + // and the unique id of the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + +py::array build_blocks_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const int seed, + const bool verbose, + const bool use_one_sent_blocks) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } +} + +PYBIND11_MODULE(helpers, m) { + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); +} diff --git a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py new file mode 100644 index 000000000..6dac35ff9 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py @@ -0,0 +1,156 @@ +import itertools +import random + +import numpy as np +from torch.utils.data import Dataset + +from megatron import get_tokenizer +from megatron import get_args +from megatron.data.dataset_utils import get_indexed_dataset_ +from megatron.data.realm_dataset_utils import get_block_samples_mapping + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + +def get_ict_dataset(use_titles=True, query_in_block_prob=1): + """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) + rather than for training, since it is only built with a single epoch sample mapping. + """ + args = get_args() + block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) + titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) + + kwargs = dict( + name='full', + block_dataset=block_dataset, + title_dataset=titles_dataset, + data_prefix=args.data_path, + num_epochs=1, + max_num_samples=None, + max_seq_length=args.seq_length, + seed=1, + query_in_block_prob=query_in_block_prob, + use_titles=use_titles, + use_one_sent_docs=args.use_one_sent_docs + ) + dataset = ICTDataset(**kwargs) + return dataset + + +class ICTDataset(Dataset): + """Dataset containing sentences and their blocks for an inverse cloze task.""" + def __init__(self, name, block_dataset, title_dataset, data_prefix, + num_epochs, max_num_samples, max_seq_length, query_in_block_prob, + seed, use_titles=True, use_one_sent_docs=False, binary_head=False): + self.name = name + self.seed = seed + self.max_seq_length = max_seq_length + self.query_in_block_prob = query_in_block_prob + self.block_dataset = block_dataset + self.title_dataset = title_dataset + self.rng = random.Random(self.seed) + self.use_titles = use_titles + self.use_one_sent_docs = use_one_sent_docs + + self.samples_mapping = get_block_samples_mapping( + block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs) + self.tokenizer = get_tokenizer() + self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_list = self.tokenizer.inv_vocab + self.cls_id = self.tokenizer.cls + self.sep_id = self.tokenizer.sep + self.mask_id = self.tokenizer.mask + self.pad_id = self.tokenizer.pad + + def __len__(self): + return len(self.samples_mapping) + + def __getitem__(self, idx): + """Get an ICT example of a pseudo-query and the block of text from which it was extracted""" + sample_data = self.samples_mapping[idx] + start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple() + + if self.use_titles: + title = self.title_dataset[int(doc_idx)] + title_pad_offset = 3 + len(title) + else: + title = None + title_pad_offset = 2 + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1 + + # randint() is inclusive for Python rng + rand_sent_idx = self.rng.randint(0, len(block) - 1) + + # keep the query in the context query_in_block_prob fraction of the time. + if self.rng.random() < self.query_in_block_prob: + query = block[rand_sent_idx].copy() + else: + query = block.pop(rand_sent_idx) + + # still need to truncate because blocks are concluded when + # the sentence lengths have exceeded max_seq_length. + query = query[:self.max_seq_length - 2] + block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] + + query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) + context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) + + query_mask = make_attention_mask(query_tokens, query_tokens) + context_mask = make_attention_mask(context_tokens, context_tokens) + + block_data = sample_data.as_array() + + sample = { + 'query_tokens': query_tokens, + 'query_mask': query_mask, + 'query_pad_mask': query_pad_mask, + 'context_tokens': context_tokens, + 'context_mask': context_mask, + 'context_pad_mask': context_pad_mask, + 'block_data': block_data, + } + + return sample + + def get_block(self, start_idx, end_idx, doc_idx): + """Get the IDs for an evidence block plus the title of the corresponding document""" + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + title = self.title_dataset[int(doc_idx)] + + block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def get_null_block(self): + """Get empty block and title - used in REALM pretraining""" + block, title = [], [] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def concat_and_pad_tokens(self, tokens, title=None): + """Concat with special tokens and pad sequence to self.max_seq_length""" + tokens = list(tokens) + if title is None: + tokens = [self.cls_id] + tokens + [self.sep_id] + else: + title = list(title) + tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id] + assert len(tokens) <= self.max_seq_length + + num_pad = self.max_seq_length - len(tokens) + pad_mask = [1] * len(tokens) + [0] * num_pad + tokens += [self.pad_id] * num_pad + + return np.array(tokens), np.array(pad_mask) diff --git a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py new file mode 100644 index 000000000..b4febcd82 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py @@ -0,0 +1,569 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + +from functools import lru_cache +import os +import shutil +import struct +from itertools import accumulate + +import numpy as np +import torch + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ['lazy', 'cached', 'mmap'] + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return 'cached' + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return 'mmap' + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == 'mmap': + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == 'infer': + impl = infer_dataset_impl(path) + if impl == 'lazy' and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == 'cached' and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == 'mmap' and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == 'mmap': + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float, + 7: np.double, + 8: np.uint16 +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + '.idx' + + +def data_file_path(prefix_path): + return prefix_path + '.bin' + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + _HDR_MAGIC = b'TNTIDX\x00\x00' + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + 'Index file doesn\'t match expected format. ' + 'Make sure that --dataset-impl is configured properly.' + ) + version = f.read(8) + assert struct.unpack('= self._len: + raise IndexError('index out of range') + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return ( + os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx: ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx: ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float: 4, + np.double: 8 + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, 'wb') + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), 'rb') as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, 'wb') + index.write(b'TNTIDX\x00\x00') + index.write(struct.pack(' len(ds.doc_idx) - 1: + args.count = len(ds.doc_idx) - 1 + + for i in range(args.count): + start = ds.doc_idx[i] + end = ds.doc_idx[i + 1] + ids = ds[start:end] + print(f"Document {i}:") + print("--------------") + for s in ids: + assert len(s) > 0 + l = s.data.tolist() + text = tokenizer.detokenize(l) + print(text) + print("---") + + +def test_indexed_dataset_get(args): + ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) + tokenizer = build_tokenizer(args) + size = ds.sizes[0] + print(f"size: {size}") + full = ds.get(0) + print(full) + # print(tokenizer.detokenize(full.data.tolist())) + print("---") + end = ds.get(0, offset=size - 10) + print(end) + # print(tokenizer.detokenize(end.data.tolist())) + + start = ds.get(0, length=10) + print(start) + # print(tokenizer.detokenize(start.data.tolist())) + + part = ds.get(0, offset=2, length=8) + print(part) + # print(tokenizer.detokenize(part.data.tolist())) + +# def test_albert_dataset(args): +# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) +# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) +# # ds = AlbertDataset(idataset, tokenizer) +# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, +# args.epochs, args.max_num_samples, +# args.masked_lm_prob, args.seq_length, +# args.short_seq_prob, args.seed) +# truncated = 0 +# total = 0 +# for i, s in enumerate(ds): +# ids = s['text'] +# tokens = ds.tokenizer.convert_ids_to_tokens(ids) +# print(tokens) +# if i >= args.count-1: +# exit() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, help='prefix to data files') + parser.add_argument('--dataset-impl', type=str, default='infer', + choices=['lazy', 'cached', 'mmap', 'infer']) + parser.add_argument('--count', type=int, default=10, + help='Number of samples/documents to print') + + group = parser.add_argument_group(title='tokenizer') + group.add_argument('--tokenizer-type', type=str, required=True, + choices=['BertWordPieceLowerCase', + 'GPT2BPETokenizer'], + help='What type of tokenizer to use.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file') + group.add_argument('--merge-file', type=str, default=None, + help='Path to the BPE merge file (if necessary).') + + parser.add_argument('--epochs', type=int, default=5, + help='Number of epochs to plan for') + parser.add_argument('--max-num-samples', type=int, default=None, + help='Maximum number of samples to plan for') + parser.add_argument('--masked-lm-prob', type=float, default=0.15, + help='probability of masking tokens') + parser.add_argument('--seq-length', type=int, default=512, + help='maximum sequence length') + parser.add_argument('--short-seq-prob', type=float, default=0.1, + help='probability of creating a short sequence') + parser.add_argument('--seed', type=int, default=1234, + help='random seed') + args = parser.parse_args() + args.rank = 0 + args.make_vocab_size_divisible_by = 128 + args.tensor_model_parallel_size = 1 + + if args.dataset_impl == "infer": + args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) + +# test_albert_dataset(args) + test_indexed_dataset_get(args) + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/sequence_parallel/data/datasets/test/test_preprocess_data.sh b/examples/tutorial/sequence_parallel/data/datasets/test/test_preprocess_data.sh new file mode 100755 index 000000000..d121c8595 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/datasets/test/test_preprocess_data.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +IMPL=cached +python ../preprocess_data.py \ + --input test_samples.json \ + --vocab vocab.txt \ + --dataset-impl ${IMPL} \ + --output-prefix test_samples_${IMPL} \ + --workers 1 \ + --log-interval 2 diff --git a/examples/tutorial/sequence_parallel/data/dummy_dataloader.py b/examples/tutorial/sequence_parallel/data/dummy_dataloader.py new file mode 100644 index 000000000..faa90175c --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/dummy_dataloader.py @@ -0,0 +1,39 @@ +import torch + + +class DummyDataloader(): + + def __init__(self, batch_size, vocab_size, seq_length): + self.batch_size = batch_size + self.vocab_size = vocab_size + self.seq_length = seq_length + self.step = 0 + + def generate(self): + tokens = torch.randint(low=0, high=self.vocab_size, size=( + self.batch_size, + self.seq_length, + )) + types = torch.randint(low=0, high=3, size=( + self.batch_size, + self.seq_length, + )) + sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,)) + loss_mask = torch.randint(low=0, high=2, size=( + self.batch_size, + self.seq_length, + )) + lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length)) + padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length)) + return dict(text=tokens, + types=types, + is_random=sentence_order, + loss_mask=loss_mask, + labels=lm_labels, + padding_mask=padding_mask) + + def __iter__(self): + return self + + def __next__(self): + return self.generate() \ No newline at end of file diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/__init__.py b/examples/tutorial/sequence_parallel/data/tokenizer/__init__.py new file mode 100644 index 000000000..df27f1424 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/tokenizer/__init__.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .tokenizer import build_tokenizer + + +_TOKENIZER = None +_PADDED_VOCAB_SIZE = -1 + + +def initialize_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): + tokenizer, padded_vocab_size = build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids) + global _TOKENIZER, _PADDED_VOCAB_SIZE + _TOKENIZER = tokenizer + _PADDED_VOCAB_SIZE = padded_vocab_size + + +def get_tokenizer(): + global _TOKENIZER + return _TOKENIZER + + +def get_padded_vocab_size(): + global _PADDED_VOCAB_SIZE + return _PADDED_VOCAB_SIZE diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/bert_tokenization.py b/examples/tutorial/sequence_parallel/data/tokenizer/bert_tokenization.py new file mode 100644 index 000000000..1be494793 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/tokenizer/bert_tokenization.py @@ -0,0 +1,431 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re +import unicodedata +import six + + +def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. + + if not init_checkpoint: + return + + m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", + "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" + ] + + cased_models = [ + "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", + "multi_cased_L-12_H-768_A-12" + ] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = "False" + case_name = "lowercased" + opposite_flag = "True" + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = "True" + case_name = "cased" + opposite_flag = "False" + + if is_bad_config: + raise ValueError( + "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " + "However, `%s` seems to be a %s model, so you " + "should pass in `--do_lower_case=%s` so that the fine-tuning matches " + "how the model was pre-training. If this error is wrong, please " + "just comment out this check." % (actual_flag, init_checkpoint, + model_name, case_name, opposite_flag)) + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r") as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenization.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """ Converts a sequence of tokens (string) in a single string. """ + + def clean_up_tokenization(out_string): + """ Clean up a list of simple English tokenization artifacts + like spaces before punctuations and abbreviated forms. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py new file mode 100644 index 000000000..ee3c923e8 --- /dev/null +++ b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron tokenizers.""" + +from abc import ABC +from abc import abstractmethod +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode + +from .bert_tokenization import FullTokenizer as FullBertTokenizer + + +def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): + """Initialize tokenizer.""" + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + print('> building {} tokenizer ...'.format(tokenizer_type), + flush=True) + + # Select and instantiate the tokenizer. + if tokenizer_type == 'BertWordPieceLowerCase': + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, + lower_case=True, + vocab_extra_ids=vocab_extra_ids) + elif tokenizer_type == 'BertWordPieceCase': + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, + lower_case=False, + vocab_extra_ids=vocab_extra_ids) + else: + raise NotImplementedError('{} tokenizer is not ' + 'implemented.'.format(tokenizer_type)) + + # Add vocab size. + padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size) + + return tokenizer, padded_vocab_size + + +def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): + """Pad vocab size so it is divisible by model parallel size and + still having GPU friendly size.""" + + after = orig_vocab_size + + if gpc.is_initialized(ParallelMode.TENSOR): + multiple = make_vocab_size_divisible_by * gpc.get_world_size(ParallelMode.TENSOR) + else: + multiple = make_vocab_size_divisible_by + while (after % multiple) != 0: + after += 1 + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + print(' > padded vocab (size: {}) with {} dummy tokens ' + '(new size: {})'.format( + orig_vocab_size, after - orig_vocab_size, after), flush=True) + return after + + +class AbstractTokenizer(ABC): + """Abstract class for tokenizer.""" + + def __init__(self, name): + self.name = name + super().__init__() + + @property + @abstractmethod + def vocab_size(self): + pass + + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token.""" + pass + + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token.""" + pass + + @abstractmethod + def tokenize(self, text): + pass + + def detokenize(self, token_ids): + raise NotImplementedError('detokenizer is not implemented for {} ' + 'tokenizer'.format(self.name)) + + @property + def cls(self): + raise NotImplementedError('CLS is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def sep(self): + raise NotImplementedError('SEP is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def pad(self): + raise NotImplementedError('PAD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def eod(self): + raise NotImplementedError('EOD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def mask(self): + raise NotImplementedError('MASK is not provided for {} ' + 'tokenizer'.format(self.name)) + + +class _BertWordPieceTokenizer(AbstractTokenizer): + """Original BERT wordpiece tokenizer.""" + + def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): + if lower_case: + name = 'BERT Lower Case' + else: + name = 'BERT Upper Case' + super().__init__(name) + self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) + self.cls_id = self.tokenizer.vocab['[CLS]'] + self.sep_id = self.tokenizer.vocab['[SEP]'] + self.pad_id = self.tokenizer.vocab['[PAD]'] + self.mask_id = self.tokenizer.vocab['[MASK]'] + self._additional_special_tokens = [] + + # (dsachan) Add BOS and EOS tokens + SPECIAL_TOKENS = {'eos_token': '[EOS]', + 'bos_token': '[BOS]'} + self._bos_token = '[BOS]' + self.add_token(self._bos_token) + self._bos_token_id = self.vocab.get(self._bos_token) + + self._eos_token = '[EOS]' + self.add_token(self._eos_token) + self._eos_token_id = self.vocab.get(self._eos_token) + + # (dsachan) Add additional special tokens + # These can be used as sentinel tokens in T5 model inputs + additional_special_tokens = [] + additional_special_tokens.extend( + ["".format(i) for i in range(vocab_extra_ids)]) + self.add_additional_special_tokens(additional_special_tokens) + + def add_token(self, token): + if token not in self.vocab: + self.inv_vocab[self.vocab_size] = token + # self.vocab_size comes from len(vocab) + # and it will increase as we add elements + self.vocab[token] = self.vocab_size + + def add_additional_special_tokens(self, tokens_list): + setattr(self, "additional_special_tokens", tokens_list) + for value in tokens_list: + self.add_token(value) + + @property + def vocab_size(self): + return self.tokenizer.vocab_size() + + @property + def vocab(self): + return self.tokenizer.vocab + + @property + def inv_vocab(self): + return self.tokenizer.inv_vocab + + def tokenize(self, text): + text_tokens = self.tokenizer.tokenize(text) + return self.tokenizer.convert_tokens_to_ids(text_tokens) + + def decode(self, ids): + tokens = self.tokenizer.convert_ids_to_tokens(ids) + return self.tokenizer.convert_tokens_to_string(tokens) + + def decode_token_ids(self, token_ids): + tokens = self.tokenizer.convert_ids_to_tokens(token_ids) + exclude_list = ['[PAD]', '[CLS]'] + non_pads = [t for t in tokens if t not in exclude_list] + + result = "" + for s in non_pads: + if s.startswith("##"): + result += s[2:] + else: + result += " " + s + + return result + + @property + def cls(self): + return self.cls_id + + @property + def sep(self): + return self.sep_id + + @property + def pad(self): + return self.pad_id + + @property + def mask(self): + return self.mask_id + + @property + def bos_token(self): + """ Beginning of sentence token id """ + return self._bos_token + + @property + def eos_token(self): + """ End of sentence token id """ + return self._eos_token + + @property + def additional_special_tokens(self): + """ All the additional special tokens you may want to use (list of strings).""" + return self._additional_special_tokens + + @property + def bos_token_id(self): + """ Id of the beginning of sentence token in the vocabulary.""" + return self._bos_token_id + + @property + def eos_token_id(self): + """ Id of the end of sentence token in the vocabulary.""" + return self._eos_token_id + + @property + def additional_special_tokens_ids(self): + """ Ids of all the additional special tokens in the vocabulary (list of integers).""" + return [self.vocab.get(token) for token in self._additional_special_tokens] + + @additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value diff --git a/examples/tutorial/sequence_parallel/loss_func/__init__.py b/examples/tutorial/sequence_parallel/loss_func/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py new file mode 100644 index 000000000..e87a778cf --- /dev/null +++ b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.logging import get_dist_logger +import torch.nn.functional as F +import torch.distributed as dist +from .cross_entropy import vocab_cross_entropy + + +class BertLoss(nn.Module): + + def forward(self, + lm_loss, + sop_logits, + loss_mask, + sentence_order): + lm_loss_ = lm_loss.float() + loss_mask = loss_mask.float() + loss_mask_sum = loss_mask.sum() + lm_loss = torch.sum( + lm_loss_.view(-1) * loss_mask.reshape(-1)) + + lm_loss /= loss_mask_sum + + torch.distributed.all_reduce( + lm_loss, + group=gpc.get_group(ParallelMode.SEQUENCE) + ) + + if sop_logits is not None: + sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), + sentence_order.view(-1), + ignore_index=-1) + sop_loss = sop_loss.float() + loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE) + else: + sop_loss = None + loss = lm_loss + + return loss diff --git a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py new file mode 100644 index 000000000..54553c29a --- /dev/null +++ b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py @@ -0,0 +1,75 @@ +from colossalai.context.parallel_mode import ParallelMode +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + + +class _VocabCrossEntropy(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, vocab_parallel_logits, target): + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + + # Subtract the maximum value. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = target < 0 + masked_target = target.clone() + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], + device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], + device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= ( + 1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None + + +def vocab_cross_entropy(vocab_logits, target): + """helper function for the cross entropy.""" + + return _VocabCrossEntropy.apply(vocab_logits, target) diff --git a/examples/tutorial/sequence_parallel/loss_func/utils.py b/examples/tutorial/sequence_parallel/loss_func/utils.py new file mode 100644 index 000000000..a3d92f294 --- /dev/null +++ b/examples/tutorial/sequence_parallel/loss_func/utils.py @@ -0,0 +1,55 @@ + +import torch + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim(tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class VocabUtility: + """Split the vocabulary into `world_size` chunks amd return the + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" + + @staticmethod + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, world_size): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size) diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/__init__.py b/examples/tutorial/sequence_parallel/lr_scheduler/__init__.py new file mode 100644 index 000000000..2b8b615bc --- /dev/null +++ b/examples/tutorial/sequence_parallel/lr_scheduler/__init__.py @@ -0,0 +1 @@ +from .annealing_lr import AnnealingLR diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py new file mode 100644 index 000000000..8d95679ff --- /dev/null +++ b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Learning rate decay functions.""" + +import math + + +class AnnealingLR(object): + """Anneals the learning rate.""" + + def __init__(self, + optimizer, + max_lr, + min_lr, + warmup_steps, + decay_steps, + decay_style, + use_checkpoint_lr_scheduler=True, + override_lr_scheduler=False): + + # Class values. + self.optimizer = optimizer + + self.max_lr = float(max_lr) + self.min_lr = min_lr + assert self.min_lr >= 0.0 + assert self.max_lr >= self.min_lr + + self.warmup_steps = warmup_steps + self.num_steps = 0 + self.decay_steps = decay_steps + assert self.decay_steps > 0 + assert self.warmup_steps < self.decay_steps + + self.decay_style = decay_style + + self.override_lr_scheduler = override_lr_scheduler + self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler + if self.override_lr_scheduler: + assert not self.use_checkpoint_lr_scheduler, 'both override and '\ + 'use-checkpoint are set.' + + # Set the learning rate + self.step(0) + + def get_lr(self): + """Learning rate decay functions from: + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + + # Use linear warmup for the initial part. + if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: + return self.max_lr * float(self.num_steps) / \ + float(self.warmup_steps) + + # If the learning rate is constant, just return the initial value. + if self.decay_style == 'constant': + return self.max_lr + + # For any steps larger than `self.decay_steps`, use `self.min_lr`. + if self.num_steps > self.decay_steps: + return self.min_lr + + # If we are done with the warmup period, use the decay style. + num_steps_ = self.num_steps - self.warmup_steps + decay_steps_ = self.decay_steps - self.warmup_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = self.max_lr - self.min_lr + + if self.decay_style == 'linear': + coeff = (1.0 - decay_ratio) + elif self.decay_style == 'cosine': + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + else: + raise Exception('{} decay style is not supported.'.format( + self.decay_style)) + + return self.min_lr + coeff * delta_lr + + def step(self, increment=1): + """Set lr for all parameters groups.""" + self.num_steps += increment + new_lr = self.get_lr() + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + def state_dict(self): + state_dict = { + 'max_lr': self.max_lr, + 'warmup_steps': self.warmup_steps, + 'num_steps': self.num_steps, + 'decay_style': self.decay_style, + 'decay_steps': self.decay_steps, + 'min_lr': self.min_lr + } + return state_dict + + def _check_and_set(self, cls_value, sd_value, name): + """Auxiliary function for checking the values in the checkpoint and + setting them.""" + if self.override_lr_scheduler: + return cls_value + + if not self.use_checkpoint_lr_scheduler: + assert cls_value == sd_value, \ + f'AnnealingLR: class input value {cls_value} and checkpoint' \ + f'value {sd_value} for {name} do not match' + return sd_value + + def load_state_dict(self, sd): + + if 'start_lr' in sd: + max_lr_ = sd['start_lr'] + else: + max_lr_ = sd['max_lr'] + self.max_lr = self._check_and_set(self.max_lr, max_lr_, + 'learning rate') + + self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], + 'minimum learning rate') + + if 'warmup_iter' in sd: + warmup_steps_ = sd['warmup_iter'] + else: + warmup_steps_ = sd['warmup_steps'] + self.warmup_steps = self._check_and_set(self.warmup_steps, + warmup_steps_, + 'warmup iterations') + + if 'end_iter' in sd: + decay_steps_ = sd['end_iter'] + else: + decay_steps_ = sd['decay_steps'] + self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, + 'total number of iterations') + self.decay_style = self._check_and_set(self.decay_style, + sd['decay_style'], + 'decay style') + + if 'num_iters' in sd: + num_steps = sd['num_iters'] + else: + num_steps = sd['num_steps'] + self.step(increment=num_steps) diff --git a/examples/tutorial/sequence_parallel/model/__init__.py b/examples/tutorial/sequence_parallel/model/__init__.py new file mode 100644 index 000000000..139597f9c --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py new file mode 100644 index 000000000..049579c5a --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -0,0 +1,282 @@ +from colossalai.context.parallel_mode import ParallelMode +import torch +import torch.nn as nn +import inspect +from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding +from .layers.init_method import init_normal, output_init_normal +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.kernel import LayerNorm +from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.logging import get_dist_logger +from colossalai.pipeline.utils import partition_uniform + + +class BertForPretrain(nn.Module): + + def __init__(self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + ): + super().__init__() + self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) + assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + self.sub_seq_length = max_sequence_length // self.seq_parallel_size + self.init_std = init_std + self.num_layers = num_layers + + if not add_binary_head: + num_tokentypes = 0 + + self.preprocessor = PreProcessor(self.sub_seq_length) + self.embedding = Embedding(hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes) + self.bert_layers = nn.ModuleList() + + for i in range(num_layers): + bert_layer = BertLayer(layer_number=i+1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16 + ) + self.bert_layers.append(bert_layer) + + self.layer_norm = LayerNorm(hidden_size) + self.head = BertDualHead(hidden_size, self.embedding.word_embedding_weight.size(0), + add_binary_head=add_binary_head) + self.reset_parameters() + + def _init_normal(self, tensor): + init_normal(tensor, sigma=self.init_std) + + def _output_init_normal(self, tensor): + output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers) + + def reset_parameters(self): + # initialize embedding + self._init_normal(self.embedding.word_embedding_weight) + self._init_normal(self.embedding.position_embeddings.weight) + if self.embedding.tokentype_embeddings: + self._init_normal(self.embedding.tokentype_embeddings.weight) + + # initialize bert layer + for layer in self.bert_layers: + # initialize self attention + self._init_normal(layer.self_attention.query_key_value.weight) + self._output_init_normal(layer.self_attention.dense.weight) + self._init_normal(layer.mlp.dense_h_to_4h.weight) + self._output_init_normal(layer.mlp.dense_4h_to_h.weight) + + # initializer head + self._init_normal(self.head.lm_head.dense.weight) + if self.head.binary_head is not None: + self._init_normal(self.head.binary_head.pooler.dense.weight) + self._init_normal(self.head.binary_head.dense.weight) + + def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels): + # inputs of the forward function + # input_ids: [batch_size, sub_seq_len] + # attention_mask: [batch_size, seq_len] + # tokentype_ids: [batch_size, sub_seq_len] + # outputs of preprocessor + # pos_ids: [batch_size, sub_seq_len] + # attention_masks: [batch_size, 1, sub_seq_len, seq_len] + pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks) + + hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids) + + # hidden_states shape change: + # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size] + hidden_states = hidden_states.transpose(0, 1).contiguous() + + for idx, layer in enumerate(self.bert_layers): + hidden_states = layer(hidden_states, attention_masks) + + hidden_states = hidden_states.transpose(0, 1).contiguous() + output = self.layer_norm(hidden_states) + + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # word_embedding: [vocab_size, hidden_size] + return self.head(output, self.embedding.word_embedding_weight, lm_labels) + + +class PipelineBertForPretrain(nn.Module): + + def __init__(self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + first_stage=True, + last_stage=True, + start_idx=None, + end_idx=None): + super().__init__() + self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) + assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + self.sub_seq_length = max_sequence_length // self.seq_parallel_size + self.init_std = init_std + self.num_layers = num_layers + + if not add_binary_head: + num_tokentypes = 0 + + self.first_stage = first_stage + self.last_stage = last_stage + + self.preprocessor = PreProcessor(self.sub_seq_length) + + if self.first_stage: + self.embedding = Embedding(hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes) + + # transformer layers + self.bert_layers = nn.ModuleList() + + if start_idx is None and end_idx is None: + start_idx = 0 + end_idx = num_layers + + for i in range(start_idx, end_idx): + bert_layer = BertLayer(layer_number=i+1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16 + ) + self.bert_layers.append(bert_layer) + + if self.last_stage: + self.word_embeddings = VocabEmbedding(vocab_size, hidden_size) + self.layer_norm = LayerNorm(hidden_size) + self.head = BertDualHead(hidden_size, vocab_size, + add_binary_head=add_binary_head) + self.reset_parameters() + + def _init_normal(self, tensor): + init_normal(tensor, sigma=self.init_std) + + def _output_init_normal(self, tensor): + output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers) + + def reset_parameters(self): + # initialize embedding + if self.first_stage: + self._init_normal(self.embedding.word_embedding_weight) + self._init_normal(self.embedding.position_embeddings.weight) + if self.embedding.tokentype_embeddings: + self._init_normal(self.embedding.tokentype_embeddings.weight) + + # initialize bert layer + for layer in self.bert_layers: + # initialize self attention + self._init_normal(layer.self_attention.query_key_value.weight) + self._output_init_normal(layer.self_attention.dense.weight) + self._init_normal(layer.mlp.dense_h_to_4h.weight) + self._output_init_normal(layer.mlp.dense_4h_to_h.weight) + + # initializer head + if self.last_stage: + self._init_normal(self.head.lm_head.dense.weight) + if self.head.binary_head is not None: + self._init_normal(self.head.binary_head.pooler.dense.weight) + self._init_normal(self.head.binary_head.dense.weight) + + def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels): + # inputs of the forward function + # input_ids: [batch_size, sub_seq_len] + # attention_mask: [batch_size, seq_len] + # tokentype_ids: [batch_size, sub_seq_len] + # outputs of preprocessor + # pos_ids: [batch_size, sub_seq_len] + # attention_masks: [batch_size, 1, sub_seq_len, seq_len] + if self.first_stage: + pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks) + else: + _, attention_masks = self.preprocessor(None, attention_masks) + + if self.first_stage: + hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids) + hidden_states = hidden_states.transpose(0, 1).contiguous() + else: + hidden_states = input_ids + + # hidden_states shape change: + # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size] + for idx, layer in enumerate(self.bert_layers): + hidden_states = layer(hidden_states, attention_masks) + + if self.last_stage: + hidden_states = hidden_states.transpose(0, 1).contiguous() + output = self.layer_norm(hidden_states) + output = self.head(output, self.word_embeddings.weight, lm_labels) + else: + output = hidden_states + + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # word_embedding: [vocab_size, hidden_size] + return output + + +def _filter_kwargs(func, kwargs): + sig = inspect.signature(func) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): + logger = get_dist_logger() + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) + parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] + models = [] + for start, end in parts: + kwargs['num_layers'] = num_layers + kwargs['start_idx'] = start + kwargs['end_idx'] = end + kwargs['first_stage'] = start == 0 + kwargs['last_stage'] = end == num_layers + logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device) + if start == 0: + wrapper.register_module(chunk.embedding.word_embeddings) + elif end == num_layers: + wrapper.register_module(chunk.word_embeddings) + models.append(chunk) + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + return model diff --git a/examples/tutorial/sequence_parallel/model/layers/__init__.py b/examples/tutorial/sequence_parallel/model/layers/__init__.py new file mode 100644 index 000000000..3a8823caa --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/__init__.py @@ -0,0 +1,4 @@ +from .embedding import VocabEmbedding, Embedding +from .bert_layer import BertLayer +from .head import BertDualHead +from .preprocess import PreProcessor diff --git a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py new file mode 100644 index 000000000..4ede21516 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing +from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from colossalai.kernel.cuda_native import LayerNorm +from .mlp import TransformerMLP +from .dropout import get_bias_dropout_add + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +class BertLayer(nn.Module): + """A single transformer layer. + Transformer layer takes input with size [b, s, h] and returns an + output of the same size. + """ + + def __init__(self, + layer_number, + hidden_size, + num_attention_heads, + attention_dropout, + mlp_ratio, + hidden_dropout, + is_naive_fp16, + apply_residual_connection_post_layernorm=False, + fp32_residual_connection=False, + bias_dropout_fusion: bool = True, + convert_fp16_to_fp32_in_softmax: bool = False): + super().__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.fp32_residual_connection = fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm(hidden_size) + + # Self attention. + self.self_attention = TransformerSelfAttentionRing( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=attention_dropout, + attention_mask_func=attention_mask_func, + layer_number=layer_number, + apply_query_key_layer_scaling=True, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + fp16=is_naive_fp16 + ) + + self.hidden_dropout = hidden_dropout + self.bias_dropout_fusion = bias_dropout_fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm(hidden_size) + + self.mlp = TransformerMLP(hidden_size=hidden_size, mlp_ratio=mlp_ratio) + + def forward(self, hidden_states, attention_mask): + # hidden_states: [batch_size, sub_seq_len, hidden_size] + # attention_mask: [batch_size, 1, sub_seq_len, seq_len] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_bias = self.self_attention(layernorm_output, attention_mask) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + # re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + layernorm_input = bias_dropout_add_func( + attention_output, + attention_bias.expand_as(residual), + residual, + self.hidden_dropout) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output, mlp_bias = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + # re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + output = bias_dropout_add_func( + mlp_output, + mlp_bias.expand_as(residual), + residual, + self.hidden_dropout) + + return output diff --git a/examples/tutorial/sequence_parallel/model/layers/dropout.py b/examples/tutorial/sequence_parallel/model/layers/dropout.py new file mode 100644 index 000000000..0e99105b8 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/dropout.py @@ -0,0 +1,13 @@ +import torch + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add \ No newline at end of file diff --git a/examples/tutorial/sequence_parallel/model/layers/embedding.py b/examples/tutorial/sequence_parallel/model/layers/embedding.py new file mode 100644 index 000000000..0700d960d --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/embedding.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + + +class VocabEmbedding(torch.nn.Module): + + def __init__(self, num_embeddings, embedding_dim): + super(VocabEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + + # Allocate weights and initialize. + self.weight = nn.Parameter(torch.empty( + self.num_embeddings, self.embedding_dim)) + init.xavier_uniform_(self.weight) + + def forward(self, hidden_state): + output = F.embedding(hidden_state, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + return output + + def __repr__(self): + return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \ + f'embedding_dim={self.embedding_dim})' + + +class Embedding(nn.Module): + """Language model embeddings. + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + num_tokentypes): + super(Embedding, self).__init__() + + self.hidden_size = hidden_size + self.num_tokentypes = num_tokentypes + + self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size) + + # Position embedding (serial). + self.position_embeddings = torch.nn.Embedding( + max_sequence_length, self.hidden_size) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, + self.hidden_size) + else: + self.tokentype_embeddings = None + + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + @property + def word_embedding_weight(self): + return self.word_embeddings.weight + + def forward(self, input_ids, position_ids, tokentype_ids=None): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + if tokentype_ids is not None and self.tokentype_embeddings is not None: + embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) + + # Dropout. + embeddings = self.embedding_dropout(embeddings) + + return embeddings diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py new file mode 100644 index 000000000..ea336b9d1 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -0,0 +1,78 @@ +import colossalai +import torch +import torch.nn as nn +import torch.nn.functional as F +from .pooler import Pooler +from .linear import Linear +from .embedding import VocabEmbedding +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.kernel import LayerNorm +from loss_func.cross_entropy import vocab_cross_entropy + + +class BertLMHead(nn.Module): + """Masked LM head for Bert + Arguments: + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + """ + + def __init__(self, + vocab_size, + hidden_size, + ): + + super(BertLMHead, self).__init__() + self.bias = torch.nn.Parameter(torch.zeros(vocab_size)) + + self.dense = Linear(hidden_size, hidden_size) + self.layernorm = LayerNorm(hidden_size) + self.gelu = torch.nn.functional.gelu + + def forward(self, hidden_states, word_embeddings_weight, lm_labels): + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.layernorm(hidden_states) + + output = F.linear(hidden_states, word_embeddings_weight, self.bias) + lm_loss = vocab_cross_entropy(output, lm_labels) + + return lm_loss + + +class BertBinaryHead(nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.pooler = Pooler(hidden_size) + self.dense = Linear(hidden_size, 2) + + def forward(self, hidden_states): + if gpc.get_local_rank(ParallelMode.SEQUENCE) == 0: + output = self.pooler(hidden_states) + output = self.dense(output) + else: + output = None + return output + + +class BertDualHead(nn.Module): + + def __init__(self, hidden_size, vocab_size, add_binary_head): + super().__init__() + self.lm_head = BertLMHead(vocab_size, hidden_size) + self.add_binary_head = add_binary_head + if add_binary_head: + self.binary_head = BertBinaryHead(hidden_size) + else: + self.binary_head = None + + def forward(self, hidden_states, word_embeddings_weight, lm_labels): + if self.add_binary_head: + binary_output = self.binary_head(hidden_states) + else: + binary_output = None + lm_loss = self.lm_head(hidden_states, word_embeddings_weight, lm_labels) + return lm_loss, binary_output diff --git a/examples/tutorial/sequence_parallel/model/layers/init_method.py b/examples/tutorial/sequence_parallel/model/layers/init_method.py new file mode 100644 index 000000000..1b409dfe4 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/init_method.py @@ -0,0 +1,12 @@ +import torch +import math + +def init_normal(tensor, sigma): + """Init method based on N(0, sigma).""" + torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + +def output_init_normal(tensor, sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + torch.nn.init.normal_(tensor, mean=0.0, std=std) diff --git a/examples/tutorial/sequence_parallel/model/layers/linear.py b/examples/tutorial/sequence_parallel/model/layers/linear.py new file mode 100644 index 000000000..5ae7d671e --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/linear.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +import torch.nn.init as init + + +class Linear(nn.Module): + """Linear layer with column parallelism. + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ + + def __init__(self, + input_size, + output_size, + bias=True, + skip_bias_add=False): + super(Linear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + + self.weight = Parameter(torch.empty(self.output_size, + self.input_size, + )) + init.normal_(self.weight) + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output = F.linear(input_, self.weight, bias) + + if self.skip_bias_add: + return output, self.bias + else: + return output + + def __repr__(self): + return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ + f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' diff --git a/examples/tutorial/sequence_parallel/model/layers/mlp.py b/examples/tutorial/sequence_parallel/model/layers/mlp.py new file mode 100644 index 000000000..a255de813 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/mlp.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .linear import Linear +from colossalai.kernel.jit import bias_gelu_impl + + +class TransformerMLP(nn.Module): + """MLP. + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. + """ + + def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True): + super(TransformerMLP, self).__init__() + + # Project to 4h. + self.dense_h_to_4h = Linear( + hidden_size, + int(hidden_size*mlp_ratio), + skip_bias_add=True) + + self.bias_gelu_fusion = fuse_gelu + self.activation_func = F.gelu + + # Project back to h. + self.dense_4h_to_h = Linear( + int(hidden_size*mlp_ratio), + hidden_size, + skip_bias_add=True) + + def forward(self, hidden_states): + # hidden states should be in the shape of [s, b, h] + # it will be projects into [s, b, 4h] + # and projected back to [s, b, h] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if self.bias_gelu_fusion: + intermediate_parallel = \ + bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + intermediate_parallel = \ + self.activation_func(intermediate_parallel + bias_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias diff --git a/examples/tutorial/sequence_parallel/model/layers/pooler.py b/examples/tutorial/sequence_parallel/model/layers/pooler.py new file mode 100644 index 000000000..282ed1147 --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/pooler.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +from .linear import Linear + + +class Pooler(nn.Module): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size): + super(Pooler, self).__init__() + self.dense = Linear(hidden_size, hidden_size) + + def forward(self, hidden_states, sequence_index=0): + # hidden_states: [b, s, h] + # sequence_index: index of the token to pool. + pooled = hidden_states[:, sequence_index, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled diff --git a/examples/tutorial/sequence_parallel/model/layers/preprocess.py b/examples/tutorial/sequence_parallel/model/layers/preprocess.py new file mode 100644 index 000000000..53a326dda --- /dev/null +++ b/examples/tutorial/sequence_parallel/model/layers/preprocess.py @@ -0,0 +1,58 @@ +from colossalai.context.parallel_mode import ParallelMode +import torch +import torch.nn as nn +from colossalai.core import global_context as gpc + + +class PreProcessor(nn.Module): + + def __init__(self, sub_seq_length): + super().__init__() + self.sub_seq_length = sub_seq_length + + def bert_position_ids(self, token_ids): + # Create position ids + seq_length = token_ids.size(1) + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + position_ids = torch.arange(seq_length*local_rank, + seq_length * (local_rank+1), + dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + def bert_extended_attention_mask(self, attention_mask): + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + start_index = local_rank * self.sub_seq_length + end_index = (local_rank + 1) * self.sub_seq_length + + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s/D, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + + attention_mask_bss = attention_mask_bss[:, start_index:end_index, :] + + # [b, 1, s/D, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = (extended_attention_mask < 0.5) + + return extended_attention_mask + + def forward(self, input_ids=None, attention_mask=None): + if attention_mask is not None: + extended_attention_mask = self.bert_extended_attention_mask(attention_mask) + else: + extended_attention_mask = None + + if input_ids is not None: + position_ids = self.bert_position_ids(input_ids) + else: + position_ids = None + return position_ids, extended_attention_mask diff --git a/examples/tutorial/sequence_parallel/requirements.txt b/examples/tutorial/sequence_parallel/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/tutorial/sequence_parallel/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py new file mode 100644 index 000000000..b92061000 --- /dev/null +++ b/examples/tutorial/sequence_parallel/train.py @@ -0,0 +1,240 @@ +import argparse + +import torch +from data import build_train_valid_test_data_iterators +from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel +from data.tokenizer import get_padded_vocab_size, initialize_tokenizer +from loss_func.bert_loss import BertLoss +from lr_scheduler import AnnealingLR +from model.bert import BertForPretrain, build_pipeline_bert + +import colossalai +from colossalai.amp import AMP_TYPE +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule import PipelineSchedule +from colossalai.kernel import LayerNorm +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import FusedAdam +from colossalai.utils import MultiTimer, is_using_pp + + +def process_batch_data(batch_data): + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = batch_data + if gpc.is_first_rank(ParallelMode.PIPELINE): + data = dict(input_ids=tokens, attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels) + else: + data = dict(attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels) + label = dict(loss_mask=loss_mask, sentence_order=sentence_order) + return data, label + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + return parser.parse_args() + + +def pipeline_data_process_func(stage_output, micro_batch_data): + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data + if gpc.is_first_rank(ParallelMode.PIPELINE): + data = (tokens, padding_mask, types, lm_labels) + label = (loss_mask, sentence_order) + else: + data = (stage_output, padding_mask, types, lm_labels) + label = (loss_mask, sentence_order) + return data, label + + +def main(): + # initialize + args = parse_args() + colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl') + + logger = get_dist_logger() + + # build dataloader + if not args.synthetic: + initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase') + VOCAB_SIZE = get_padded_vocab_size() + trainloader, validloader, testloader = build_train_valid_test_data_iterators( + train_iters=gpc.config.TRAIN_ITERS, + global_batch_size=gpc.config.GLOBAL_BATCH_SIZE, + eval_interval=gpc.config.EVAL_INTERVAL, + eval_iters=gpc.config.EVAL_ITERS, + data_prefix=[gpc.config.DATA_PATH], + data_impl='mmap', + splits_string='949,50,1', + max_seq_length=gpc.config.SEQ_LENGTH, + masked_lm_prob=0.15, + short_seq_prob=0.1, + seed=1234, + skip_warmup=True, + binary_head=False, + ) + else: + from data.dummy_dataloader import DummyDataloader + + BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA) + VOCAB_SIZE = 30528 + trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, + vocab_size=VOCAB_SIZE, + seq_length=gpc.config.SEQ_LENGTH) + validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, + vocab_size=VOCAB_SIZE, + seq_length=gpc.config.SEQ_LENGTH) + + logger.info("Dataloaders are built", ranks=[0]) + + # build model + if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE: + is_naive_fp16 = True + else: + is_naive_fp16 = False + + use_pipeline = is_using_pp() + kwargs = dict(vocab_size=VOCAB_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + max_sequence_length=gpc.config.SEQ_LENGTH, + num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, + convert_fp16_to_fp32_in_softmax=True, + is_naive_fp16=is_naive_fp16, + add_binary_head=gpc.config.ADD_BINARY_HEAD) + + if use_pipeline: + model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs) + else: + model = BertForPretrain(num_layers=gpc.config.DEPTH, **kwargs) + + model = model.half() + model.reset_parameters() + logger.info(f"Model is built with softmax in fp32 = {is_naive_fp16}", ranks=[0]) + + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + logger.info(f"This model has {total_numel} parameters") + + # build criterion + criterion = BertLoss() + logger.info("Criterion is built", ranks=[0]) + + # layernorm and bias has no weight decay + weight_decay_params = {'params': []} + no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + for module_ in model.modules(): + if isinstance(module_, LayerNorm): + no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None]) + else: + weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias']) + no_weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias']) + + logger.info( + f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}" + ) + # optimizer + optimizer = FusedAdam((weight_decay_params, no_weight_decay_params), + lr=gpc.config.LR, + weight_decay=gpc.config.WEIGHT_DECAY) + logger.info("Optimizer is built", ranks=[0]) + + # lr scheduler + # follow Megatron-LM setting + warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION) + lr_scheduler = AnnealingLR(optimizer=optimizer, + max_lr=gpc.config.LR, + min_lr=gpc.config.MIN_LR, + warmup_steps=warmup_steps, + decay_steps=gpc.config.DECAY_ITERS, + decay_style='linear') + logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps") + + # # init + engine, *dummy = colossalai.initialize(model, optimizer, criterion, verbose=True) + + # build timer + timer = MultiTimer() + skip_iters = 0 + + # build loss tracker + accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda() + accumulated_eval_loss = torch.zeros(1, dtype=torch.float32).cuda() + + # build data iters for pipeline parallel + if use_pipeline: + train_data_iter = SequenceParallelDataIterator(trainloader) + valid_data_iter = SequenceParallelDataIterator(validloader) + engine.schedule.data_process_func = pipeline_data_process_func + + logger.info("start training") + + for step in range(1, gpc.config.TRAIN_ITERS + 1): + timer.start('train-iterations') + engine.train() + if use_pipeline: + engine.zero_grad() + _, _, train_loss = engine.execute_schedule(train_data_iter, return_output_label=False) + engine.step() + else: + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( + trainloader) + engine.zero_grad() + lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) + train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) + engine.backward(train_loss) + engine.step() + timer.stop('train-iterations', keep_in_history=True) + + if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): + accumulated_train_loss += train_loss + + lr_scheduler.step() + + if step % gpc.config.EVAL_INTERVAL == 0: + engine.eval() + + for j in range(gpc.config.EVAL_ITERS): + with torch.no_grad(): + if use_pipeline: + _, _, eval_loss = engine.execute_schedule(valid_data_iter, + forward_only=True, + return_output_label=False) + else: + tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( + validloader) + lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) + eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) + + if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): + accumulated_eval_loss += eval_loss + + if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): + accumulated_eval_loss /= gpc.config.EVAL_ITERS + accumulated_train_loss /= gpc.config.EVAL_INTERVAL + + timer_string = [] + for n, t in timer: + timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}") + timer_string = ' | '.join(timer_string) + lr = list(engine.optimizer.param_groups)[0]['lr'] + loss_scale = engine.optimizer.optim.loss_scale.item() + + if gpc.is_initialized(ParallelMode.PIPELINE): + ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]] + else: + ranks = [0] + logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' + + f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' + + f"| Learning rate: {lr} | " + timer_string, + ranks=ranks) + + for n, t in timer: + t.reset() + accumulated_eval_loss.zero_() + accumulated_train_loss.zero_() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/stable_diffusion/LICENSE b/examples/tutorial/stable_diffusion/LICENSE new file mode 100644 index 000000000..0e609df0d --- /dev/null +++ b/examples/tutorial/stable_diffusion/LICENSE @@ -0,0 +1,82 @@ +Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors + +CreativeML Open RAIL-M +dated August 22, 2022 + +Section I: PREAMBLE + +Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and Licensor agree as follows: + +1. Definitions + +- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +- "Output" means the results of operating a Model as embodied in informational content resulting therefrom. +- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. +- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. +- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. +- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. +- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." +- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +You must cause any modified files to carry prominent notices stating that You changed the files; +You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). +6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. +8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. +9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. +10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. +11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. + +END OF TERMS AND CONDITIONS + + + + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: +- In any way that violates any applicable national, federal, state, local or international law or regulation; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate personal identifiable information that can be used to harm an individual; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; +- To provide medical advice and medical results interpretation; +- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). diff --git a/examples/tutorial/stable_diffusion/README.md b/examples/tutorial/stable_diffusion/README.md new file mode 100644 index 000000000..a0ece4485 --- /dev/null +++ b/examples/tutorial/stable_diffusion/README.md @@ -0,0 +1,149 @@ +# Stable Diffusion with Colossal-AI +*[Colosssal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and +fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).* + +We take advantage of [Colosssal-AI](https://github.com/hpcaitech/ColossalAI) to exploit multiple optimization strategies +, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs. + +## 🚀Quick Start +1. Create a new environment for diffusion +```bash +conda env create -f environment.yaml +conda activate ldm +``` +2. Install Colossal-AI from our official page +```bash +pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org +``` +3. Install PyTorch Lightning compatible commit +```bash +git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa +pip install -r requirements.txt && pip install . +cd .. +``` + +4. Comment out the `from_pretrained` field in the `train_colossalai_cifar10.yaml`. +5. Run training with CIFAR10. +```bash +python main.py -logdir /tmp -t true -postfix test -b configs/train_colossalai_cifar10.yaml +``` + +## Stable Diffusion +[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion +model. +Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. +Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487), +this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. + +

+ +

+ +[Stable Diffusion with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) provides **6.5x faster training and pretraining cost saving, the hardware cost of fine-tuning can be almost 7X cheaper** (from RTX3090/4090 24GB to RTX3050/2070 8GB). + +

+ +

+ +## Requirements +A suitable [conda](https://conda.io/) environment named `ldm` can be created +and activated with: + +``` +conda env create -f environment.yaml +conda activate ldm +``` + +You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running + +``` +conda install pytorch torchvision -c pytorch +pip install transformers==4.19.2 diffusers invisible-watermark +pip install -e . +``` + +### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website +``` +pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org +``` + +### Install [Lightning](https://github.com/Lightning-AI/lightning) +We use the Sep. 2022 version with commit id as `b04a7aa`. +``` +git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa +pip install -r requirements.txt && pip install . +``` + +> The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future. + +## Dataset +The dataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/), +you should the change the `data.file_path` in the `config/train_colossalai.yaml` + +## Training + +We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml` + +For example, you can run the training from colossalai by +``` +python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml +``` + +- you can change the `--logdir` the save the log information and the last checkpoint + +### Training config +You can change the trainging config in the yaml file + +- accelerator: acceleratortype, default 'gpu' +- devices: device number used for training, default 4 +- max_epochs: max training epochs +- precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai + +## Example + +### Training on cifar10 + +We provide the finetuning example on CIFAR10 dataset + +You can run by config `train_colossalai_cifar10.yaml` +``` +python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml +``` + + + +## Comments + +- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion) +, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch), +[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion). +Thanks for open-sourcing! + +- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories). + +- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch). + +## BibTeX + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +@misc{rombach2021highresolution, + title={High-Resolution Image Synthesis with Latent Diffusion Models}, + author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer}, + year={2021}, + eprint={2112.10752}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +@article{dao2022flashattention, + title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, + author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + journal={arXiv preprint arXiv:2205.14135}, + year={2022} +} +``` diff --git a/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml b/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml new file mode 100644 index 000000000..c457787dd --- /dev/null +++ b/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml @@ -0,0 +1,116 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + use_fp16: True + +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + wrap: False + train: + target: ldm.data.base.Txt2ImgIterableBaseDataset + params: + file_path: "/data/scratch/diffuser/laion_part0/" + world_size: 1 + rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 4 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: pytorch_lightning.strategies.ColossalAIStrategy + params: + use_chunk: False + enable_distributed_storage: True, + placement_policy: cuda + force_outputs_fp32: False + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + profiler: pytorch + + logger_config: + wandb: + target: pytorch_lightning.loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml b/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml new file mode 100644 index 000000000..63b9d1c01 --- /dev/null +++ b/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml @@ -0,0 +1,123 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: txt + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + use_fp16: True + +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + num_workers: 4 + train: + target: ldm.data.cifar10.hf_dataset + params: + name: cifar10 + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip + +lightning: + trainer: + accelerator: 'gpu' + devices: 2 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: pytorch_lightning.strategies.ColossalAIStrategy + params: + use_chunk: False + enable_distributed_storage: True, + placement_policy: cuda + force_outputs_fp32: False + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + profiler: pytorch + + logger_config: + wandb: + target: pytorch_lightning.loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/configs/train_ddp.yaml b/examples/tutorial/stable_diffusion/configs/train_ddp.yaml new file mode 100644 index 000000000..90d41258f --- /dev/null +++ b/examples/tutorial/stable_diffusion/configs/train_ddp.yaml @@ -0,0 +1,113 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 32 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 100 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + use_fp16: True + +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + wrap: False + train: + target: ldm.data.base.Txt2ImgIterableBaseDataset + params: + file_path: "/data/scratch/diffuser/laion_part0/" + world_size: 1 + rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 4 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: pytorch_lightning.strategies.DDPStrategy + params: + find_unused_parameters: False + log_every_n_steps: 2 +# max_steps: 6o + logger: True + default_root_dir: "/tmp/diff_log/" + # profiler: pytorch + + logger_config: + wandb: + target: pytorch_lightning.loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml b/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml new file mode 100644 index 000000000..8b5d2adfa --- /dev/null +++ b/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml @@ -0,0 +1,121 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 32 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + check_nan_inf: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1.e-4 ] + f_min: [ 1.e-10 ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + use_fp16: True + +data: + target: main.DataModuleFromConfig + params: + batch_size: 32 + wrap: False + train: + target: ldm.data.pokemon.PokemonDataset + # params: + # file_path: "/data/scratch/diffuser/laion_part0/" + # world_size: 1 + # rank: 0 + +lightning: + trainer: + accelerator: 'gpu' + devices: 4 + log_gpu_memory: all + max_epochs: 2 + precision: 16 + auto_select_gpus: False + strategy: + target: pytorch_lightning.strategies.ColossalAIStrategy + params: + use_chunk: False + enable_distributed_storage: True, + placement_policy: cuda + force_outputs_fp32: False + initial_scale: 65536 + min_scale: 1 + max_scale: 65536 + # max_scale: 4294967296 + + log_every_n_steps: 2 + logger: True + default_root_dir: "/tmp/diff_log/" + profiler: pytorch + + logger_config: + wandb: + target: pytorch_lightning.loggers.WandbLogger + params: + name: nowname + save_dir: "/tmp/diff_log/" + offline: opt.debug + id: nowname \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/environment.yaml b/examples/tutorial/stable_diffusion/environment.yaml new file mode 100644 index 000000000..7d8aec86f --- /dev/null +++ b/examples/tutorial/stable_diffusion/environment.yaml @@ -0,0 +1,34 @@ +name: ldm +channels: + - pytorch + - defaults +dependencies: + - python=3.9.12 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.11.0 + - torchvision=0.12.0 + - numpy=1.19.2 + - pip: + - albumentations==0.4.3 + - datasets + - diffusers + - opencv-python==4.6.0.66 + - pudb==2019.2 + - invisible-watermark + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.8.0 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit>=0.73.1 + - einops==0.3.0 + - torch-fidelity==0.3.0 + - transformers==4.19.2 + - torchmetrics==0.7.0 + - kornia==0.6 + - prefetch_generator + - colossalai + - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers + - -e git+https://github.com/openai/CLIP.git@main#egg=clip + - -e . diff --git a/examples/tutorial/stable_diffusion/ldm/data/__init__.py b/examples/tutorial/stable_diffusion/ldm/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/tutorial/stable_diffusion/ldm/data/base.py b/examples/tutorial/stable_diffusion/ldm/data/base.py new file mode 100644 index 000000000..4f3cd3571 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/data/base.py @@ -0,0 +1,75 @@ +import math +from abc import abstractmethod + +import torch +from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset +import os +import numpy as np +import cv2 + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, file_path: str, rank, world_size): + super().__init__() + self.file_path = file_path + self.folder_list = [] + self.file_list = [] + self.txt_list = [] + self.info = self._get_file_info(file_path) + self.start = self.info['start'] + self.end = self.info['end'] + self.rank = rank + + self.world_size = world_size + # self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size))) + # self.iter_start = self.start + self.rank * self.per_worker + # self.iter_end = min(self.iter_start + self.per_worker, self.end) + # self.num_records = self.iter_end - self.iter_start + # self.valid_ids = [i for i in range(self.iter_end)] + self.num_records = self.end - self.start + self.valid_ids = [i for i in range(self.end)] + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + # return self.iter_end - self.iter_start + return self.end - self.start + + def __iter__(self): + sample_iterator = self._sample_generator(self.start, self.end) + # sample_iterator = self._sample_generator(self.iter_start, self.iter_end) + return sample_iterator + + def _sample_generator(self, start, end): + for idx in range(start, end): + file_name = self.file_list[idx] + txt_name = self.txt_list[idx] + f_ = open(txt_name, 'r') + txt_ = f_.read() + f_.close() + image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = torch.from_numpy(image) / 255 + yield {"caption": txt_, "image":image} + + + def _get_file_info(self, file_path): + info = \ + { + "start": 1, + "end": 0, + } + self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i] + for folder in self.folder_list: + files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i] + txts = [k.replace('jpg', 'txt') for k in files] + self.file_list.extend(files) + self.txt_list.extend(txts) + info['end'] = len(self.file_list) + # with open(file_path, 'r') as fin: + # for _ in enumerate(fin): + # info['end'] += 1 + # self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list] + return info \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/data/cifar10.py b/examples/tutorial/stable_diffusion/ldm/data/cifar10.py new file mode 100644 index 000000000..53cd61263 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/data/cifar10.py @@ -0,0 +1,184 @@ +from typing import Dict +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't suport captions yet + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename) + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +def hf_dataset( + name, + image_transforms=[], + image_column="img", + label_column="label", + text_column="txt", + split='train', + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(name, split=split) + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + tform = transforms.Compose(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + + label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + + processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]] + + return processed + + ds.set_transform(pre_process) + return ds + +class TextOnly(Dataset): + def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): + """Returns only captions with dummy images""" + self.output_size = output_size + self.image_key = image_key + self.caption_key = caption_key + if isinstance(captions, Path): + self.captions = self._load_caption_file(captions) + else: + self.captions = captions + + if n_gpus > 1: + # hack to make sure that all the captions appear on each gpu + repeated = [n_gpus*[x] for x in self.captions] + self.captions = [] + [self.captions.extend(x) for x in repeated] + + def __len__(self): + return len(self.captions) + + def __getitem__(self, index): + dummy_im = torch.zeros(3, self.output_size, self.output_size) + dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') + return {self.image_key: dummy_im, self.caption_key: self.captions[index]} + + def _load_caption_file(self, filename): + with open(filename, 'rt') as f: + captions = f.readlines() + return [x.strip('\n') for x in captions] \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/data/imagenet.py b/examples/tutorial/stable_diffusion/ldm/data/imagenet.py new file mode 100644 index 000000000..1c473f9c6 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/examples/tutorial/stable_diffusion/ldm/data/lsun.py b/examples/tutorial/stable_diffusion/ldm/data/lsun.py new file mode 100644 index 000000000..6256e4571 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py b/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py new file mode 100644 index 000000000..be39da9ca --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py b/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py new file mode 100644 index 000000000..873d8b69b --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py @@ -0,0 +1,544 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + from_pretrained: str=None + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + from diffusers.modeling_utils import load_state_dict + if from_pretrained is not None: + state_dict = load_state_dict(from_pretrained) + self._load_pretrained_model(state_dict) + + def _state_key_mapping(self, state_dict: dict): + import re + res_dict = {} + key_list = state_dict.keys() + key_str = " ".join(key_list) + up_block_pattern = re.compile('upsamplers') + p1 = re.compile('mid.block_[0-9]') + p2 = re.compile('decoder.up.[0-9]') + up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1) + for key_, val_ in state_dict.items(): + key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\ + .replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\ + .replace('mid.attentions.0.key', 'mid.attn_1.k')\ + .replace('mid.attentions.0.query', 'mid.attn_1.q') \ + .replace('mid.attentions.0.value', 'mid.attn_1.v') \ + .replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \ + .replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\ + .replace('upsamplers.0', 'upsample')\ + .replace('downsamplers.0', 'downsample')\ + .replace('conv_shortcut', 'nin_shortcut')\ + .replace('conv_norm_out', 'norm_out') + + mid_list = re.findall(p1, key_) + if len(mid_list) != 0: + mid_str = mid_list[0] + mid_id = int(mid_str[-1]) + 1 + key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id)) + + up_list = re.findall(p2, key_) + if len(up_list) != 0: + up_str = up_list[0] + up_id = up_blocks_count - 1 -int(up_str[-1]) + key_ = key_.replace(up_str, up_str[:-1] + str(up_id)) + res_dict[key_] = val_ + return res_dict + + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): + state_dict = self._state_key_mapping(state_dict) + model_state_dict = self.state_dict() + loaded_keys = [k for k in state_dict.keys()] + expected_keys = list(model_state_dict.keys()) + original_loaded_keys = loaded_keys + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = self._load_state_dict_into_model(state_dict) + return missing_keys, unexpected_keys, mismatched_keys, error_msgs + + def _load_state_dict_into_model(self, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self) + + return error_msgs + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/__init__.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py new file mode 100644 index 000000000..67e98b9d8 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py new file mode 100644 index 000000000..91335d637 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py @@ -0,0 +1,240 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py new file mode 100644 index 000000000..9633ec3d8 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1554 @@ +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid + +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d +from ldm.modules.x_transformer import * +from ldm.modules.encoders.modules import * + +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import * +from ldm.models.diffusion.ddim import * +from ldm.modules.diffusionmodules.openaimodel import * +from ldm.modules.diffusionmodules.model import * + + +from ldm.modules.diffusionmodules.model import Model, Encoder, Decoder + +from ldm.util import instantiate_from_config + +from einops import rearrange, repeat + + + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + use_fp16 = True, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + rank_zero_info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.unet_config = unet_config + self.conditioning_key = conditioning_key + # self.model = DiffusionWrapper(unet_config, conditioning_key) + # count_params(self.model, verbose=True) + self.use_ema = use_ema + # if self.use_ema: + # self.model_ema = LitEma(self.model) + # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.ckpt_path = ckpt_path + self.ignore_keys = ignore_keys + self.load_only_unet = load_only_unet + self.given_betas = given_betas + self.beta_schedule = beta_schedule + self.timesteps = timesteps + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + # if ckpt_path is not None: + # self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + # + # self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + # linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar_init = logvar_init + # self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + # if self.learn_logvar: + # self.logvar = nn.Parameter(self.logvar, requires_grad=True) + # self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.use_fp16 = use_fp16 + if use_fp16: + self.unet_config["params"].update({"use_fp16": True}) + rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"])) + else: + self.unet_config["params"].update({"use_fp16": False}) + rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"])) + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + + if pred.isnan().any(): + print("Warning: Prediction has nan values") + lr = self.optimizers().param_groups[0]['lr'] + # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + print(f"lr: {lr}") + if pred.isinf().any(): + print("Warning: Prediction has inf values") + + if self.use_fp16: + target = target.half() + + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + if loss.isnan().any(): + print("Warning: loss has nan values") + print("loss: ", loss[0][0][0]) + raise ValueError("loss has nan values") + if loss.isinf().any(): + print("Warning: loss has inf values") + print("loss: ", loss) + raise ValueError("loss has inf values") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + # print("+" * 30) + # print(batch['jpg'].shape) + # print(len(batch['txt'])) + # print(k) + # print("=" * 30) + if not isinstance(batch, torch.Tensor): + x = batch[k] + else: + x = batch + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + + if self.use_fp16: + x = x.to(memory_format=torch.contiguous_format).float().half() + else: + x = x.to(memory_format=torch.contiguous_format).float() + + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + use_fp16=True, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, use_fp16=use_fp16, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.first_stage_config = first_stage_config + self.cond_stage_config = cond_stage_config + if self.use_fp16: + self.cond_stage_config["params"].update({"use_fp16": True}) + rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"])) + else: + self.cond_stage_config["params"].update({"use_fp16": False}) + rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"])) + # self.instantiate_first_stage(first_stage_config) + # self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + + + def configure_sharded_model(self) -> None: + self.model = DiffusionWrapper(self.unet_config, self.conditioning_key) + count_params(self.model, verbose=True) + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + + self.register_schedule(given_betas=self.given_betas, beta_schedule=self.beta_schedule, timesteps=self.timesteps, + linear_start=self.linear_start, linear_end=self.linear_end, cosine_s=self.cosine_s) + + self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + # self.logvar = nn.Parameter(self.logvar, requires_grad=True) + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, self.ignore_keys) + self.restarted_from_ckpt = True + + # TODO() + # for p in self.model.modules(): + # if not p.parameters().data.is_contiguous: + # p.data = p.data.contiguous() + + self.instantiate_first_stage(self.first_stage_config) + self.instantiate_cond_stage(self.cond_stage_config) + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + + + @rank_zero_only + @torch.no_grad() + # def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox', 'txt']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + from colossalai.nn.optimizer import HybridAdam + opt = HybridAdam(params, lr=lr) + # opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + rank_zero_info("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py new file mode 100644 index 000000000..78eeb1003 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py @@ -0,0 +1,236 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/examples/tutorial/stable_diffusion/ldm/modules/attention.py b/examples/tutorial/stable_diffusion/ldm/modules/attention.py new file mode 100644 index 000000000..3401ceafd --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/attention.py @@ -0,0 +1,314 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from torch.utils import checkpoint + +try: + from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv + FlASH_AVAILABLE = True +except: + FlASH_AVAILABLE = False + +USE_FLASH = False + + +def enable_flash_attention(): + global USE_FLASH + USE_FLASH = True + if FlASH_AVAILABLE is False: + print("Please install flash attention to activate new attention kernel.\n" + + "Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'") + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + dim_head = q.shape[-1] / self.heads + + if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \ + dim_head <= 128 and (dim_head % 8) == 0: + # print("in flash") + if q.shape[1] == k.shape[1]: + out = self._flash_attention_qkv(q, k, v) + else: + out = self._flash_attention_q_kv(q, k, v) + else: + out = self._native_attention(q, k, v, self.heads, mask) + + return self.to_out(out) + + def _native_attention(self, q, k, v, h, mask): + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + # attention, what we cannot get enough of + out = sim.softmax(dim=-1) + out = einsum('b i j, b j d -> b i d', out, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return out + + def _flash_attention_qkv(self, q, k, v): + qkv = torch.stack([q, k, v], dim=2) + b = qkv.shape[0] + n = qkv.shape[1] + qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads) + out = flash_attention_qkv(qkv, self.scale, b, n) + out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads) + return out + + def _flash_attention_q_kv(self, q, k, v): + kv = torch.stack([k, v], dim=2) + b = q.shape[0] + q_seqlen = q.shape[1] + kv_seqlen = kv.shape[1] + q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads) + kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads) + out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen) + out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads) + return out + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.use_checkpoint = use_checkpoint + + def forward(self, x, context=None): + + + if self.use_checkpoint: + return checkpoint(self._forward, x, context) + else: + return self._forward(x, context) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, use_checkpoint=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + x = x.contiguous() + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = x.contiguous() + x = self.proj_out(x) + return x + x_in \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py new file mode 100644 index 000000000..3c28492c5 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,862 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + +class temb_module(nn.Module): + def __init__(self): + super().__init__() + pass + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + # self.temb = nn.Module() + self.temb = temb_module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + # down = nn.Module() + down = Down_module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + # self.mid = nn.Module() + self.mid = Mid_module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + # up = nn.Module() + up = Up_module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + +class Down_module(nn.Module): + def __init__(self): + super().__init__() + pass + +class Up_module(nn.Module): + def __init__(self): + super().__init__() + pass + +class Mid_module(nn.Module): + def __init__(self): + super().__init__() + pass + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + # down = nn.Module() + down = Down_module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + # self.mid = nn.Module() + self.mid = Mid_module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + # self.mid = nn.Module() + self.mid = Mid_module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + # up = nn.Module() + up = Up_module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 000000000..3aedc2205 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1152 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from torch.utils import checkpoint + +from ldm.modules.diffusionmodules.util import ( + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + # for n,p in x.named_parameter(): + # print(f"convert module {n} to_f16") + # p.data = p.data.half() + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + if self.use_checkpoint: + return checkpoint(self._forward, x, emb) + else: + return self._forward(x, emb) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + if self.use_checkpoint: + return checkpoint(self._forward, x) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + else: + return self._forward(x) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + from_pretrained: str=None + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_checkpoint=use_checkpoint, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + # if use_fp16: + # self.convert_to_fp16() + from diffusers.modeling_utils import load_state_dict + if from_pretrained is not None: + state_dict = load_state_dict(from_pretrained) + self._load_pretrained_model(state_dict) + + def _input_blocks_mapping(self, input_dict): + res_dict = {} + for key_, value_ in input_dict.items(): + id_0 = int(key_[13]) + if "resnets" in key_: + id_1 = int(key_[23]) + target_id = 3 * id_0 + 1 + id_1 + post_fix = key_[25:].replace('time_emb_proj', 'emb_layers.1')\ + .replace('norm1', 'in_layers.0')\ + .replace('norm2', 'out_layers.0')\ + .replace('conv1', 'in_layers.2')\ + .replace('conv2', 'out_layers.3')\ + .replace('conv_shortcut', 'skip_connection') + res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_ + elif "attentions" in key_: + id_1 = int(key_[26]) + target_id = 3 * id_0 + 1 + id_1 + post_fix = key_[28:] + res_dict["input_blocks." + str(target_id) + '.1.' + post_fix] = value_ + elif "downsamplers" in key_: + post_fix = key_[35:] + target_id = 3 * (id_0 + 1) + res_dict["input_blocks." + str(target_id) + '.0.op.' + post_fix] = value_ + return res_dict + + + def _mid_blocks_mapping(self, mid_dict): + res_dict = {} + for key_, value_ in mid_dict.items(): + if "resnets" in key_: + temp_key_ =key_.replace('time_emb_proj', 'emb_layers.1') \ + .replace('norm1', 'in_layers.0') \ + .replace('norm2', 'out_layers.0') \ + .replace('conv1', 'in_layers.2') \ + .replace('conv2', 'out_layers.3') \ + .replace('conv_shortcut', 'skip_connection')\ + .replace('middle_block.resnets.0', 'middle_block.0')\ + .replace('middle_block.resnets.1', 'middle_block.2') + res_dict[temp_key_] = value_ + elif "attentions" in key_: + res_dict[key_.replace('attentions.0', '1')] = value_ + return res_dict + + def _other_blocks_mapping(self, other_dict): + res_dict = {} + for key_, value_ in other_dict.items(): + tmp_key = key_.replace('conv_in', 'input_blocks.0.0')\ + .replace('time_embedding.linear_1', 'time_embed.0')\ + .replace('time_embedding.linear_2', 'time_embed.2')\ + .replace('conv_norm_out', 'out.0')\ + .replace('conv_out', 'out.2') + res_dict[tmp_key] = value_ + return res_dict + + + def _output_blocks_mapping(self, output_dict): + res_dict = {} + for key_, value_ in output_dict.items(): + id_0 = int(key_[14]) + if "resnets" in key_: + id_1 = int(key_[24]) + target_id = 3 * id_0 + id_1 + post_fix = key_[26:].replace('time_emb_proj', 'emb_layers.1') \ + .replace('norm1', 'in_layers.0') \ + .replace('norm2', 'out_layers.0') \ + .replace('conv1', 'in_layers.2') \ + .replace('conv2', 'out_layers.3') \ + .replace('conv_shortcut', 'skip_connection') + res_dict["output_blocks." + str(target_id) + '.0.' + post_fix] = value_ + elif "attentions" in key_: + id_1 = int(key_[27]) + target_id = 3 * id_0 + id_1 + post_fix = key_[29:] + res_dict["output_blocks." + str(target_id) + '.1.' + post_fix] = value_ + elif "upsamplers" in key_: + post_fix = key_[34:] + target_id = 3 * (id_0 + 1) - 1 + mid_str = '.2.conv.' if target_id != 2 else '.1.conv.' + res_dict["output_blocks." + str(target_id) + mid_str + post_fix] = value_ + return res_dict + + def _state_key_mapping(self, state_dict: dict): + import re + res_dict = {} + input_dict = {} + mid_dict = {} + output_dict = {} + other_dict = {} + for key_, value_ in state_dict.items(): + if "down_blocks" in key_: + input_dict[key_.replace('down_blocks', 'input_blocks')] = value_ + elif "up_blocks" in key_: + output_dict[key_.replace('up_blocks', 'output_blocks')] = value_ + elif "mid_block" in key_: + mid_dict[key_.replace('mid_block', 'middle_block')] = value_ + else: + other_dict[key_] = value_ + + input_dict = self._input_blocks_mapping(input_dict) + output_dict = self._output_blocks_mapping(output_dict) + mid_dict = self._mid_blocks_mapping(mid_dict) + other_dict = self._other_blocks_mapping(other_dict) + # key_list = state_dict.keys() + # key_str = " ".join(key_list) + + # for key_, val_ in state_dict.items(): + # key_ = key_.replace("down_blocks", "input_blocks")\ + # .replace("up_blocks", 'output_blocks') + # res_dict[key_] = val_ + res_dict.update(input_dict) + res_dict.update(output_dict) + res_dict.update(mid_dict) + res_dict.update(other_dict) + + return res_dict + + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): + state_dict = self._state_key_mapping(state_dict) + model_state_dict = self.state_dict() + loaded_keys = [k for k in state_dict.keys()] + expected_keys = list(model_state_dict.keys()) + original_loaded_keys = loaded_keys + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = self._load_state_dict_into_model(state_dict) + return missing_keys, unexpected_keys, mismatched_keys, error_msgs + + def _load_state_dict_into_model(self, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self) + + return error_msgs + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(self.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(self.dtype) + return self.out(h) + diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py new file mode 100644 index 000000000..a7db9369c --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,276 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + if use_fp16: + return embedding.half() + else: + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels, precision=16): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + if precision == 16: + return GroupNorm16(16, channels) + else: + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + +class GroupNorm16(nn.GroupNorm): + def forward(self, x): + return super().forward(x.half()).type(x.dtype) + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/distributions/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/distributions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py b/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py new file mode 100644 index 000000000..f2b8ef901 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/examples/tutorial/stable_diffusion/ldm/modules/ema.py b/examples/tutorial/stable_diffusion/ldm/modules/ema.py new file mode 100644 index 000000000..c8c75af43 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/examples/tutorial/stable_diffusion/ldm/modules/encoders/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/encoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py b/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py new file mode 100644 index 000000000..8cfc01e5d --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py @@ -0,0 +1,264 @@ +import types + +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig +import kornia +from transformers.models.clip.modeling_clip import CLIPTextTransformer + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +class CLIPTextModelZero(CLIPTextModel): + config_class = CLIPTextConfig + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformerZero(config) + +class CLIPTextTransformerZero(CLIPTextTransformer): + def _build_causal_attention_mask(self, bsz, seq_len): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask.half() + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_fp16=True): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + + if use_fp16: + self.transformer = CLIPTextModelZero.from_pretrained(version) + else: + self.transformer = CLIPTextModel.from_pretrained(version) + + # print(self.transformer.modules()) + # print("check model dtyoe: {}, {}".format(self.tokenizer.dtype, self.transformer.dtype)) + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + # tokens = batch_encoding["input_ids"].to(self.device) + tokens = batch_encoding["input_ids"].to(self.device) + # print("token type: {}".format(tokens.dtype)) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)) + + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py b/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py new file mode 100644 index 000000000..2a7a73879 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py @@ -0,0 +1,50 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton) +""" + +import torch +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func +except ImportError: + raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention') + + + +def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len): + """ + Arguments: + qkv: (batch*seq, 3, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Return: + out: (total, nheads, headdim). + """ + max_s = seq_len + cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, + device=qkv.device) + out = flash_attn_unpadded_qkvpacked_func( + qkv, cu_seqlens, max_s, 0.0, + softmax_scale=sm_scale, causal=False + ) + return out + + +def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen): + """ + Arguments: + q: (batch*seq, nheads, headdim) + kv: (batch*seq, 2, nheads, headdim) + batch_size: int. + seq_len: int. + sm_scale: float. The scaling of QK^T before applying softmax. + Return: + out: (total, nheads, headdim). + """ + cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device) + out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale) + return out diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py new file mode 100644 index 000000000..7836cada8 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 000000000..32ef56169 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 000000000..9e1f82399 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 000000000..4249b43de Binary files /dev/null and b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png differ diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 000000000..0175f155a --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py new file mode 100644 index 000000000..876d7c5bd --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py b/examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py new file mode 100644 index 000000000..672c1e32a --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py b/examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py new file mode 100644 index 000000000..f69981769 --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,167 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + +def l1(x, y): + return torch.abs(x-y) + + +def l2(x, y): + return torch.pow((x-y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py b/examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py new file mode 100644 index 000000000..5fc15bf9c --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/examples/tutorial/stable_diffusion/ldm/util.py b/examples/tutorial/stable_diffusion/ldm/util.py new file mode 100644 index 000000000..8ba38853e --- /dev/null +++ b/examples/tutorial/stable_diffusion/ldm/util.py @@ -0,0 +1,203 @@ +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/examples/tutorial/stable_diffusion/main.py b/examples/tutorial/stable_diffusion/main.py new file mode 100644 index 000000000..7cd00e4c0 --- /dev/null +++ b/examples/tutorial/stable_diffusion/main.py @@ -0,0 +1,830 @@ +import argparse, os, sys, datetime, glob, importlib, csv +import numpy as np +import time +import torch +import torchvision +import pytorch_lightning as pl + +from packaging import version +from omegaconf import OmegaConf +from torch.utils.data import random_split, DataLoader, Dataset, Subset +from functools import partial +from PIL import Image +# from pytorch_lightning.strategies.colossalai import ColossalAIStrategy +# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from prefetch_generator import BackgroundGenerator + +from pytorch_lightning import seed_everything +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info +from diffusers.models.unet_2d import UNet2DModel + +from clip.model import Bottleneck +from transformers.models.clip.modeling_clip import CLIPTextTransformer + +from ldm.data.base import Txt2ImgIterableBaseDataset +from ldm.util import instantiate_from_config +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel +import kornia + +from ldm.modules.x_transformer import * +from ldm.modules.encoders.modules import * +from taming.modules.diffusionmodules.model import ResnetBlock +from taming.modules.transformer.mingpt import * +from taming.modules.transformer.permuter import * + + +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import AutoencoderKL +from ldm.models.autoencoder import * +from ldm.models.diffusion.ddim import * +from ldm.modules.diffusionmodules.openaimodel import * +from ldm.modules.diffusionmodules.model import * +from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module +from ldm.modules.attention import enable_flash_attention + +class DataLoaderX(DataLoader): + + def __iter__(self): + return BackgroundGenerator(super().__iter__()) + + +def get_parser(**parser_kwargs): + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser(**parser_kwargs) + parser.add_argument( + "-n", + "--name", + type=str, + const=True, + default="", + nargs="?", + help="postfix for logdir", + ) + parser.add_argument( + "-r", + "--resume", + type=str, + const=True, + default="", + nargs="?", + help="resume from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", + default=list(), + ) + parser.add_argument( + "-t", + "--train", + type=str2bool, + const=True, + default=False, + nargs="?", + help="train", + ) + parser.add_argument( + "--no-test", + type=str2bool, + const=True, + default=False, + nargs="?", + help="disable test", + ) + parser.add_argument( + "-p", + "--project", + help="name of new or path to existing project" + ) + parser.add_argument( + "-d", + "--debug", + type=str2bool, + nargs="?", + const=True, + default=False, + help="enable post-mortem debugging", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + default=23, + help="seed for seed_everything", + ) + parser.add_argument( + "-f", + "--postfix", + type=str, + default="", + help="post-postfix for default name", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + default="logs", + help="directory for logging dat shit", + ) + parser.add_argument( + "--scale_lr", + type=str2bool, + nargs="?", + const=True, + default=True, + help="scale base-lr by ngpu * batch_size * n_accumulate", + ) + parser.add_argument( + "--use_fp16", + type=str2bool, + nargs="?", + const=True, + default=True, + help="whether to use fp16", + ) + parser.add_argument( + "--flash", + type=str2bool, + const=True, + default=False, + nargs="?", + help="whether to use flash attention", + ) + return parser + + +def nondefault_trainer_args(opt): + parser = argparse.ArgumentParser() + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args([]) + return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + + +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, Txt2ImgIterableBaseDataset): + split_size = dataset.num_records // worker_info.num_workers + # reset num_records to the true number to retain reliable length information + dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + current_id = np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + else: + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +class DataModuleFromConfig(pl.LightningDataModule): + def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, + wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, + shuffle_val_dataloader=False): + super().__init__() + self.batch_size = batch_size + self.dataset_configs = dict() + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.use_worker_init_fn = use_worker_init_fn + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + self.wrap = wrap + + def prepare_data(self): + for data_cfg in self.dataset_configs.values(): + instantiate_from_config(data_cfg) + + def setup(self, stage=None): + self.datasets = dict( + (k, instantiate_from_config(self.dataset_configs[k])) + for k in self.dataset_configs) + if self.wrap: + for k in self.datasets: + self.datasets[k] = WrappedDataset(self.datasets[k]) + + def _train_dataloader(self): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["train"], batch_size=self.batch_size, + num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn) + + def _val_dataloader(self, shuffle=False): + if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) + + def _test_dataloader(self, shuffle=False): + is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + + # do not shuffle dataloader for iterable dataset + shuffle = shuffle and (not is_iterable_dataset) + + return DataLoaderX(self.datasets["test"], batch_size=self.batch_size, + num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) + + def _predict_dataloader(self, shuffle=False): + if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size, + num_workers=self.num_workers, worker_init_fn=init_fn) + + +class SetupCallback(Callback): + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + self.config = config + self.lightning_config = lightning_config + + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + # def on_pretrain_routine_start(self, trainer, pl_module): + def on_fit_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + if "callbacks" in self.lightning_config: + if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: + os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + print("Project config") + print(OmegaConf.to_yaml(self.config)) + OmegaConf.save(self.config, + os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + + print("Lightning config") + print(OmegaConf.to_yaml(self.lightning_config)) + OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + try: + os.rename(self.logdir, dst) + except FileNotFoundError: + pass + + +class ImageLogger(Callback): + def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True, + rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, + log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = { + pl.loggers.CSVLogger: self._testtube, + } + self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + @rank_zero_only + def _testtube(self, pl_module, images, batch_idx, split): + for k in images: + grid = torchvision.utils.make_grid(images[k]) + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + + tag = f"{split}/{k}" + pl_module.logger.experiment.add_image( + tag, grid, + global_step=pl_module.global_step) + + @rank_zero_only + def log_local(self, save_dir, split, images, + global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "images", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( + k, + global_step, + current_epoch, + batch_idx) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and + callable(pl_module.log_images) and + self.max_images > 0): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) + + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1., 1.) + + self.log_local(pl_module.logger.save_dir, split, images, + pl_module.global_step, pl_module.current_epoch, batch_idx) + + logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) + logger_log_images(pl_module, images, pl_module.global_step, split) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( + check_idx > 0 or self.log_first_step): + try: + self.log_steps.pop(0) + except IndexError as e: + print(e) + pass + return True + return False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + # self.log_img(pl_module, batch, batch_idx, split="train") + pass + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if hasattr(pl_module, 'calibrate_grad_norm'): + if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + + def on_train_start(self, trainer, pl_module): + rank_zero_info("Training is starting") + + def on_train_end(self, trainer, pl_module): + rank_zero_info("Training is ending") + + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) + torch.cuda.synchronize(trainer.strategy.root_device.index) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + torch.cuda.synchronize(trainer.strategy.root_device.index) + max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20 + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.strategy.reduce(max_memory) + epoch_time = trainer.strategy.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +if __name__ == "__main__": + # custom parser to specify config files, train, test and debug mode, + # postfix, resume. + # `--key value` arguments are interpreted as arguments to the trainer. + # `nested.key=value` arguments are interpreted as config parameters. + # configs are merged from left-to-right followed by command line parameters. + + # model: + # base_learning_rate: float + # target: path to lightning module + # params: + # key: value + # data: + # target: main.DataModuleFromConfig + # params: + # batch_size: int + # wrap: bool + # train: + # target: path to train dataset + # params: + # key: value + # validation: + # target: path to validation dataset + # params: + # key: value + # test: + # target: path to test dataset + # params: + # key: value + # lightning: (optional, has sane defaults and can be specified on cmdline) + # trainer: + # additional arguments to trainer + # logger: + # logger to instantiate + # modelcheckpoint: + # modelcheckpoint to instantiate + # callbacks: + # callback1: + # target: importpath + # params: + # key: value + + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + # add cwd for convenience and to make classes in this file available when + # running as `python main.py` + # (in particular `main.DataModuleFromConfig`) + sys.path.append(os.getcwd()) + + parser = get_parser() + parser = Trainer.add_argparse_args(parser) + + opt, unknown = parser.parse_known_args() + if opt.name and opt.resume: + raise ValueError( + "-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint" + ) + if opt.flash: + enable_flash_attention() + if opt.resume: + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + paths = opt.resume.split("/") + # idx = len(paths)-paths[::-1].index("logs")+1 + # logdir = "/".join(paths[:idx]) + logdir = "/".join(paths[:-2]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), opt.resume + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") + + opt.resume_from_checkpoint = ckpt + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) + opt.base = base_configs + opt.base + _tmp = logdir.split("/") + nowname = _tmp[-1] + else: + if opt.name: + name = "_" + opt.name + elif opt.base: + cfg_fname = os.path.split(opt.base[0])[-1] + cfg_name = os.path.splitext(cfg_fname)[0] + name = "_" + cfg_name + else: + name = "" + nowname = now + name + opt.postfix + logdir = os.path.join(opt.logdir, nowname) + + ckptdir = os.path.join(logdir, "checkpoints") + cfgdir = os.path.join(logdir, "configs") + seed_everything(opt.seed) + + try: + # init and save configs + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + lightning_config = config.pop("lightning", OmegaConf.create()) + # merge trainer cli with config + trainer_config = lightning_config.get("trainer", OmegaConf.create()) + + for k in nondefault_trainer_args(opt): + trainer_config[k] = getattr(opt, k) + + print(trainer_config) + if not trainer_config["accelerator"] == "gpu": + del trainer_config["accelerator"] + cpu = True + print("Running on CPU") + else: + cpu = False + print("Running on GPU") + trainer_opt = argparse.Namespace(**trainer_config) + lightning_config.trainer = trainer_config + + # model + use_fp16 = trainer_config.get("precision", 32) == 16 + if use_fp16: + config.model["params"].update({"use_fp16": True}) + print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) + else: + config.model["params"].update({"use_fp16": False}) + print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) + + model = instantiate_from_config(config.model) + # trainer and callbacks + trainer_kwargs = dict() + + # config the logger + # default logger configs + default_logger_cfgs = { + "wandb": { + "target": "pytorch_lightning.loggers.WandbLogger", + "params": { + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, + } + }, + "tensorboard":{ + "target": "pytorch_lightning.loggers.TensorBoardLogger", + "params":{ + "save_dir": logdir, + "name": "diff_tb", + "log_graph": True + } + } + } + + default_logger_cfg = default_logger_cfgs["tensorboard"] + if "logger" in lightning_config: + logger_cfg = lightning_config.logger + else: + logger_cfg = default_logger_cfg + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) + + # config the strategy, defualt is ddp + if "strategy" in trainer_config: + strategy_cfg = trainer_config["strategy"] + print("Using strategy: {}".format(strategy_cfg["target"])) + else: + strategy_cfg = { + "target": "pytorch_lightning.strategies.DDPStrategy", + "params": { + "find_unused_parameters": False + } + } + print("Using strategy: DDPStrategy") + + trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg) + + # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to + # specify which metric is used to determine best models + default_modelckpt_cfg = { + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + } + } + if hasattr(model, "monitor"): + print(f"Monitoring {model.monitor} as checkpoint metric.") + default_modelckpt_cfg["params"]["monitor"] = model.monitor + default_modelckpt_cfg["params"]["save_top_k"] = 3 + + if "modelcheckpoint" in lightning_config: + modelckpt_cfg = lightning_config.modelcheckpoint + else: + modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) + print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") + if version.parse(pl.__version__) < version.parse('1.4.0'): + trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) + + # add callback which sets up log directory + default_callbacks_cfg = { + "setup_callback": { + "target": "main.SetupCallback", + "params": { + "resume": opt.resume, + "now": now, + "logdir": logdir, + "ckptdir": ckptdir, + "cfgdir": cfgdir, + "config": config, + "lightning_config": lightning_config, + } + }, + "image_logger": { + "target": "main.ImageLogger", + "params": { + "batch_frequency": 750, + "max_images": 4, + "clamp": True + } + }, + "learning_rate_logger": { + "target": "main.LearningRateMonitor", + "params": { + "logging_interval": "step", + # "log_momentum": True + } + }, + "cuda_callback": { + "target": "main.CUDACallback" + }, + } + if version.parse(pl.__version__) >= version.parse('1.4.0'): + default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) + + if "callbacks" in lightning_config: + callbacks_cfg = lightning_config.callbacks + else: + callbacks_cfg = OmegaConf.create() + + if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: + print( + 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') + default_metrics_over_trainsteps_ckpt_dict = { + 'metrics_over_trainsteps_checkpoint': + {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', + 'params': { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + } + } + default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) + + callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) + if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): + callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint + elif 'ignore_keys_callback' in callbacks_cfg: + del callbacks_cfg['ignore_keys_callback'] + + trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + + trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) + trainer.logdir = logdir ### + + # data + data = instantiate_from_config(config.data) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. + # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + print("#### Data #####") + for k in data.datasets: + print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + + # configure learning rate + bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate + if not cpu: + ngpu = trainer_config["devices"] + else: + ngpu = 1 + if 'accumulate_grad_batches' in lightning_config.trainer: + accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches + else: + accumulate_grad_batches = 1 + print(f"accumulate_grad_batches = {accumulate_grad_batches}") + lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches + if opt.scale_lr: + model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr + print( + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( + model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + else: + model.learning_rate = base_lr + print("++++ NOT USING LR SCALING ++++") + print(f"Setting learning rate to {model.learning_rate:.2e}") + + + # allow checkpointing via USR1 + def melk(*args, **kwargs): + # run all checkpoint hooks + if trainer.global_rank == 0: + print("Summoning checkpoint.") + ckpt_path = os.path.join(ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + + def divein(*args, **kwargs): + if trainer.global_rank == 0: + import pudb; + pudb.set_trace() + + + import signal + + signal.signal(signal.SIGUSR1, melk) + signal.signal(signal.SIGUSR2, divein) + + # run + if opt.train: + try: + for name, m in model.named_parameters(): + print(name) + trainer.fit(model, data) + except Exception: + melk() + raise + # if not opt.no_test and not trainer.interrupted: + # trainer.test(model, data) + except Exception: + if opt.debug and trainer.global_rank == 0: + try: + import pudb as debugger + except ImportError: + import pdb as debugger + debugger.post_mortem() + raise + finally: + # move newly created debug project to debug_runs + if opt.debug and not opt.resume and trainer.global_rank == 0: + dst, name = os.path.split(logdir) + dst = os.path.join(dst, "debug_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + os.rename(logdir, dst) + if trainer.global_rank == 0: + print(trainer.profiler.summary()) diff --git a/examples/tutorial/stable_diffusion/requirements.txt b/examples/tutorial/stable_diffusion/requirements.txt new file mode 100644 index 000000000..a57003562 --- /dev/null +++ b/examples/tutorial/stable_diffusion/requirements.txt @@ -0,0 +1,22 @@ +albumentations==0.4.3 +diffusers +pudb==2019.2 +datasets +invisible-watermark +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +omegaconf==2.1.1 +multiprocess +test-tube>=0.7.5 +streamlit>=0.73.1 +einops==0.3.0 +torch-fidelity==0.3.0 +transformers==4.19.2 +torchmetrics==0.6.0 +kornia==0.6 +opencv-python==4.6.0.66 +prefetch_generator +colossalai +-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers +-e git+https://github.com/openai/CLIP.git@main#egg=clip +-e . diff --git a/examples/tutorial/stable_diffusion/scripts/download_first_stages.sh b/examples/tutorial/stable_diffusion/scripts/download_first_stages.sh new file mode 100644 index 000000000..a8d79e99c --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/download_first_stages.sh @@ -0,0 +1,41 @@ +#!/bin/bash +wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip +wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip +wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip +wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip +wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip +wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip +wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip +wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip +wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip + + + +cd models/first_stage_models/kl-f4 +unzip -o model.zip + +cd ../kl-f8 +unzip -o model.zip + +cd ../kl-f16 +unzip -o model.zip + +cd ../kl-f32 +unzip -o model.zip + +cd ../vq-f4 +unzip -o model.zip + +cd ../vq-f4-noattn +unzip -o model.zip + +cd ../vq-f8 +unzip -o model.zip + +cd ../vq-f8-n256 +unzip -o model.zip + +cd ../vq-f16 +unzip -o model.zip + +cd ../.. \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/scripts/download_models.sh b/examples/tutorial/stable_diffusion/scripts/download_models.sh new file mode 100644 index 000000000..84297d7b8 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/download_models.sh @@ -0,0 +1,49 @@ +#!/bin/bash +wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip +wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip +wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip +wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip +wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip +wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip +wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip +wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip +wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip +wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip +wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip + + + +cd models/ldm/celeba256 +unzip -o celeba-256.zip + +cd ../ffhq256 +unzip -o ffhq-256.zip + +cd ../lsun_churches256 +unzip -o lsun_churches-256.zip + +cd ../lsun_beds256 +unzip -o lsun_beds-256.zip + +cd ../text2img256 +unzip -o model.zip + +cd ../cin256 +unzip -o model.zip + +cd ../semantic_synthesis512 +unzip -o model.zip + +cd ../semantic_synthesis256 +unzip -o model.zip + +cd ../bsr_sr +unzip -o model.zip + +cd ../layout2img-openimages256 +unzip -o model.zip + +cd ../inpainting_big +unzip -o model.zip + +cd ../.. diff --git a/examples/tutorial/stable_diffusion/scripts/img2img.py b/examples/tutorial/stable_diffusion/scripts/img2img.py new file mode 100644 index 000000000..421e2151d --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/img2img.py @@ -0,0 +1,293 @@ +"""make variations of input image""" + +import argparse, os, sys, glob +import PIL +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +from torch import autocast +from contextlib import nullcontext +import time +from pytorch_lightning import seed_everything + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def load_img(path): + image = Image.open(path).convert("RGB") + w, h = image.size + print(f"loaded input image of size ({w}, {h}) from {path}") + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.*image - 1. + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--init-img", + type=str, + nargs="?", + help="path to the input image" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/img2img-samples" + ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save indiviual samples. For speed measurements.", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across all samples ", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor, most often 8 or 16", + ) + parser.add_argument( + "--n_samples", + type=int, + default=2, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--strength", + type=float, + default=0.75, + help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + + opt = parser.parse_args() + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + raise NotImplementedError("PLMS sampler not (yet) supported") + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + assert os.path.isfile(opt.init_img) + init_image = load_img(opt.init_img).to(device) + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + + sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) + + assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(opt.strength * opt.ddim_steps) + print(f"target t_enc is {t_enc} steps") + + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + + # encode (scaled latent) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) + # decode it + samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc,) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + if not opt.skip_save: + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + toc = time.time() + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/stable_diffusion/scripts/inpaint.py b/examples/tutorial/stable_diffusion/scripts/inpaint.py new file mode 100644 index 000000000..d6e6387a9 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/inpaint.py @@ -0,0 +1,98 @@ +import argparse, os, sys, glob +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +from main import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler + + +def make_batch(image, mask, device): + image = np.array(Image.open(image).convert("RGB")) + image = image.astype(np.float32)/255.0 + image = image[None].transpose(0,3,1,2) + image = torch.from_numpy(image) + + mask = np.array(Image.open(mask).convert("L")) + mask = mask.astype(np.float32)/255.0 + mask = mask[None,None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = (1-mask)*image + + batch = {"image": image, "mask": mask, "masked_image": masked_image} + for k in batch: + batch[k] = batch[k].to(device=device) + batch[k] = batch[k]*2.0-1.0 + return batch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--indir", + type=str, + nargs="?", + help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + opt = parser.parse_args() + + masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) + images = [x.replace("_mask.png", ".png") for x in masks] + print(f"Found {len(masks)} inputs.") + + config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") + model = instantiate_from_config(config.model) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], + strict=False) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + with torch.no_grad(): + with model.ema_scope(): + for image, mask in tqdm(zip(images, masks)): + outpath = os.path.join(opt.outdir, os.path.split(image)[1]) + batch = make_batch(image, mask, device=device) + + # encode masked image and concat downsampled mask + c = model.cond_stage_model.encode(batch["masked_image"]) + cc = torch.nn.functional.interpolate(batch["mask"], + size=c.shape[-2:]) + c = torch.cat((c, cc), dim=1) + + shape = (c.shape[1]-1,)+c.shape[2:] + samples_ddim, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False) + x_samples_ddim = model.decode_first_stage(samples_ddim) + + image = torch.clamp((batch["image"]+1.0)/2.0, + min=0.0, max=1.0) + mask = torch.clamp((batch["mask"]+1.0)/2.0, + min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, + min=0.0, max=1.0) + + inpainted = (1-mask)*image+mask*predicted_image + inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/examples/tutorial/stable_diffusion/scripts/knn2img.py b/examples/tutorial/stable_diffusion/scripts/knn2img.py new file mode 100644 index 000000000..e6eaaecab --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/knn2img.py @@ -0,0 +1,398 @@ +import argparse, os, sys, glob +import clip +import torch +import torch.nn as nn +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +import scann +import time +from multiprocessing import cpu_count + +from ldm.util import instantiate_from_config, parallel_data_prefetch +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder + +DATABASES = [ + "openimages", + "artbench-art_nouveau", + "artbench-baroque", + "artbench-expressionism", + "artbench-impressionism", + "artbench-post_impressionism", + "artbench-realism", + "artbench-romanticism", + "artbench-renaissance", + "artbench-surrealism", + "artbench-ukiyo_e", +] + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +class Searcher(object): + def __init__(self, database, retriever_version='ViT-L/14'): + assert database in DATABASES + # self.database = self.load_database(database) + self.database_name = database + self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' + self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' + self.retriever = self.load_retriever(version=retriever_version) + self.database = {'embedding': [], + 'img_id': [], + 'patch_coords': []} + self.load_database() + self.load_searcher() + + def train_searcher(self, k, + metric='dot_product', + searcher_savedir=None): + + print('Start training searcher') + searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / + np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], + k, metric) + self.searcher = searcher.score_brute_force().build() + print('Finish training searcher') + + if searcher_savedir is not None: + print(f'Save trained searcher under "{searcher_savedir}"') + os.makedirs(searcher_savedir, exist_ok=True) + self.searcher.serialize(searcher_savedir) + + def load_single_file(self, saved_embeddings): + compressed = np.load(saved_embeddings) + self.database = {key: compressed[key] for key in compressed.files} + print('Finished loading of clip embeddings.') + + def load_multi_files(self, data_archive): + out_data = {key: [] for key in self.database} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + out_data[key].append(d[key]) + + return out_data + + def load_database(self): + + print(f'Load saved patch embedding from "{self.database_path}"') + file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + + if len(file_content) == 1: + self.load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(self.load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in + self.database} + else: + raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') + + def load_retriever(self, version='ViT-L/14', ): + model = FrozenClipImageEmbedder(model=version) + if torch.cuda.is_available(): + model.cuda() + model.eval() + return model + + def load_searcher(self): + print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) + print('Finished loading searcher.') + + def search(self, x, k): + if self.searcher is None and self.database['embedding'].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if len(x.shape) == 3: + x = x[:, 0] + query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis] + + start = time.time() + nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) + end = time.time() + + out_embeddings = self.database['embedding'][nns] + out_img_ids = self.database['img_id'][nns] + out_pc = self.database['patch_coords'][nns] + + out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + 'img_ids': out_img_ids, + 'patch_coords': out_pc, + 'queries': x, + 'exec_time': end - start, + 'nns': nns, + 'q_embeddings': query_embeddings} + + return out + + def __call__(self, x, n): + return self.search(x, n) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc) + # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt? + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--n_repeat", + type=int, + default=1, + help="number of repeats in CLIP latent space", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--H", + type=int, + default=768, + help="image height, in pixel space", + ) + + parser.add_argument( + "--W", + type=int, + default=768, + help="image width, in pixel space", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + + parser.add_argument( + "--config", + type=str, + default="configs/retrieval-augmented-diffusion/768x768.yaml", + help="path to config which constructs model", + ) + + parser.add_argument( + "--ckpt", + type=str, + default="models/rdm/rdm768x768/model.ckpt", + help="path to checkpoint of model", + ) + + parser.add_argument( + "--clip_type", + type=str, + default="ViT-L/14", + help="which CLIP model to use for retrieval and NN encoding", + ) + parser.add_argument( + "--database", + type=str, + default='artbench-surrealism', + choices=DATABASES, + help="The database used for the search, only applied when --use_neighbors=True", + ) + parser.add_argument( + "--use_neighbors", + default=False, + action='store_true', + help="Include neighbors in addition to text prompt for conditioning", + ) + parser.add_argument( + "--knn", + default=10, + type=int, + help="The number of included neighbors, only applied when --use_neighbors=True", + ) + + opt = parser.parse_args() + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + print(f"sampling scale for cfg is {opt.scale:.2f}") + + searcher = None + if opt.use_neighbors: + searcher = Searcher(opt.database) + + with torch.no_grad(): + with model.ema_scope(): + for n in trange(opt.n_iter, desc="Sampling"): + all_samples = list() + for prompts in tqdm(data, desc="data"): + print("sampling prompts:", prompts) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = clip_text_encoder.encode(prompts) + uc = None + if searcher is not None: + nn_dict = searcher(c, opt.knn) + c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + if opt.scale != 1.0: + uc = torch.zeros_like(c) + if isinstance(prompts, tuple): + prompts = list(prompts) + shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples_ddim) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/tutorial/stable_diffusion/scripts/sample_diffusion.py b/examples/tutorial/stable_diffusion/scripts/sample_diffusion.py new file mode 100644 index 000000000..876fe3c36 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/sample_diffusion.py @@ -0,0 +1,313 @@ +import argparse, os, sys, glob, datetime, yaml +import torch +import time +import numpy as np +from tqdm import trange + +from omegaconf import OmegaConf +from PIL import Image + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config + +rescale = lambda x: (x + 1.) / 2. + +def custom_to_pil(x): + x = x.detach().cpu() + x = torch.clamp(x, -1., 1.) + x = (x + 1.) / 2. + x = x.permute(1, 2, 0).numpy() + x = (255 * x).astype(np.uint8) + x = Image.fromarray(x) + if not x.mode == "RGB": + x = x.convert("RGB") + return x + + +def custom_to_np(x): + # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py + sample = x.detach().cpu() + sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) + sample = sample.permute(0, 2, 3, 1) + sample = sample.contiguous() + return sample + + +def logs2pil(logs, keys=["sample"]): + imgs = dict() + for k in logs: + try: + if len(logs[k].shape) == 4: + img = custom_to_pil(logs[k][0, ...]) + elif len(logs[k].shape) == 3: + img = custom_to_pil(logs[k]) + else: + print(f"Unknown format for key {k}. ") + img = None + except: + img = None + imgs[k] = img + return imgs + + +@torch.no_grad() +def convsample(model, shape, return_intermediates=True, + verbose=True, + make_prog_row=False): + + + if not make_prog_row: + return model.p_sample_loop(None, shape, + return_intermediates=return_intermediates, verbose=verbose) + else: + return model.progressive_denoising( + None, shape, verbose=True + ) + + +@torch.no_grad() +def convsample_ddim(model, steps, shape, eta=1.0 + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): + + + log = dict() + + shape = [batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size] + + with model.ema_scope("Plotting"): + t0 = time.time() + if vanilla: + sample, progrow = convsample(model, shape, + make_prog_row=True) + else: + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, + eta=eta) + + t1 = time.time() + + x_sample = model.decode_first_stage(sample) + + log["sample"] = x_sample + log["time"] = t1 - t0 + log['throughput'] = sample.shape[0] / (t1 - t0) + print(f'Throughput for this batch: {log["throughput"]}') + return log + +def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): + if vanilla: + print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + else: + print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') + + + tstart = time.time() + n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + # path = logdir + if model.cond_stage_model is None: + all_images = [] + + print(f"Running unconditional sampling for {n_samples} samples") + for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): + logs = make_convolutional_sample(model, batch_size=batch_size, + vanilla=vanilla, custom_steps=custom_steps, + eta=eta) + n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") + all_images.extend([custom_to_np(logs["sample"])]) + if n_saved >= n_samples: + print(f'Finish after generating {n_saved} samples') + break + all_img = np.concatenate(all_images, axis=0) + all_img = all_img[:n_samples] + shape_str = "x".join([str(x) for x in all_img.shape]) + nppath = os.path.join(nplog, f"{shape_str}-samples.npz") + np.savez(nppath, all_img) + + else: + raise NotImplementedError('Currently only sampling for unconditional models supported.') + + print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") + + +def save_logs(logs, path, n_saved=0, key="sample", np_path=None): + for k in logs: + if k == key: + batch = logs[key] + if np_path is None: + for x in batch: + img = custom_to_pil(x) + imgpath = os.path.join(path, f"{key}_{n_saved:06}.png") + img.save(imgpath) + n_saved += 1 + else: + npbatch = custom_to_np(batch) + shape_str = "x".join([str(x) for x in npbatch.shape]) + nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz") + np.savez(nppath, npbatch) + n_saved += npbatch.shape[0] + return n_saved + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--resume", + type=str, + nargs="?", + help="load from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-n", + "--n_samples", + type=int, + nargs="?", + help="number of samples to draw", + default=50000 + ) + parser.add_argument( + "-e", + "--eta", + type=float, + nargs="?", + help="eta for ddim sampling (0.0 yields deterministic sampling)", + default=1.0 + ) + parser.add_argument( + "-v", + "--vanilla_sample", + default=False, + action='store_true', + help="vanilla sampling (default option is DDIM sampling)?", + ) + parser.add_argument( + "-l", + "--logdir", + type=str, + nargs="?", + help="extra logdir", + default="none" + ) + parser.add_argument( + "-c", + "--custom_steps", + type=int, + nargs="?", + help="number of steps for ddim and fastdpm sampling", + default=50 + ) + parser.add_argument( + "--batch_size", + type=int, + nargs="?", + help="the bs", + default=10 + ) + return parser + + +def load_model_from_config(config, sd): + model = instantiate_from_config(config) + model.load_state_dict(sd,strict=False) + model.cuda() + model.eval() + return model + + +def load_model(config, ckpt, gpu, eval_mode): + if ckpt: + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + else: + pl_sd = {"state_dict": None} + global_step = None + model = load_model_from_config(config.model, + pl_sd["state_dict"]) + + return model, global_step + + +if __name__ == "__main__": + now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + sys.path.append(os.getcwd()) + command = " ".join(sys.argv) + + parser = get_parser() + opt, unknown = parser.parse_known_args() + ckpt = None + + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + # paths = opt.resume.split("/") + try: + logdir = '/'.join(opt.resume.split('/')[:-1]) + # idx = len(paths)-paths[::-1].index("logs")+1 + print(f'Logdir is {logdir}') + except ValueError: + paths = opt.resume.split("/") + idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt + logdir = "/".join(paths[:idx]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory" + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "model.ckpt") + + base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml"))) + opt.base = base_configs + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + gpu = True + eval_mode = True + + if opt.logdir != "none": + locallog = logdir.split(os.sep)[-1] + if locallog == "": locallog = logdir.split(os.sep)[-2] + print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") + logdir = os.path.join(opt.logdir, locallog) + + print(config) + + model, global_step = load_model(config, ckpt, gpu, eval_mode) + print(f"global step: {global_step}") + print(75 * "=") + print("logging to:") + logdir = os.path.join(logdir, "samples", f"{global_step:08}", now) + imglogdir = os.path.join(logdir, "img") + numpylogdir = os.path.join(logdir, "numpy") + + os.makedirs(imglogdir) + os.makedirs(numpylogdir) + print(logdir) + print(75 * "=") + + # write config out + sampling_file = os.path.join(logdir, "sampling_config.yaml") + sampling_conf = vars(opt) + + with open(sampling_file, 'w') as f: + yaml.dump(sampling_conf, f, default_flow_style=False) + print(sampling_conf) + + + run(model, imglogdir, eta=opt.eta, + vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, + batch_size=opt.batch_size, nplog=numpylogdir) + + print("done.") diff --git a/examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py b/examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py new file mode 100644 index 000000000..a32e66d44 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py @@ -0,0 +1,37 @@ +import os +import sys +from copy import deepcopy + +import yaml +from datetime import datetime + +from diffusers import StableDiffusionPipeline +import torch +from ldm.util import instantiate_from_config +from main import get_parser + +if __name__ == "__main__": + with torch.no_grad(): + yaml_path = "../../train_colossalai.yaml" + with open(yaml_path, 'r', encoding='utf-8') as f: + config = f.read() + base_config = yaml.load(config, Loader=yaml.FullLoader) + unet_config = base_config['model']['params']['unet_config'] + diffusion_model = instantiate_from_config(unet_config).to("cuda:0") + + pipe = StableDiffusionPipeline.from_pretrained( + "/data/scratch/diffuser/stable-diffusion-v1-4" + ).to("cuda:0") + dif_model_2 = pipe.unet + + random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0") + random_input_2 = torch.clone(random_input_).to("cuda:0") + time_stamp = torch.randint(20, (4,)).to("cuda:0") + time_stamp2 = torch.clone(time_stamp).to("cuda:0") + context_ = torch.rand((4, 77, 768)).to("cuda:0") + context_2 = torch.clone(context_).to("cuda:0") + + out_1 = diffusion_model(random_input_, time_stamp, context_) + out_2 = dif_model_2(random_input_2, time_stamp2, context_2) + print(out_1.shape) + print(out_2['sample'].shape) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py b/examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py new file mode 100644 index 000000000..f93f8a6e7 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py @@ -0,0 +1,18 @@ +import cv2 +import fire +from imwatermark import WatermarkDecoder + + +def testit(img_path): + bgr = cv2.imread(img_path) + decoder = WatermarkDecoder('bytes', 136) + watermark = decoder.decode(bgr, 'dwtDct') + try: + dec = watermark.decode('utf-8') + except: + dec = "null" + print(dec) + + +if __name__ == "__main__": + fire.Fire(testit) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/scripts/train_searcher.py b/examples/tutorial/stable_diffusion/scripts/train_searcher.py new file mode 100644 index 000000000..1e7904889 --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/train_searcher.py @@ -0,0 +1,147 @@ +import os, sys +import numpy as np +import scann +import argparse +import glob +from multiprocessing import cpu_count +from tqdm import tqdm + +from ldm.util import parallel_data_prefetch + + +def search_bruteforce(searcher): + return searcher.score_brute_force().build() + + +def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search): + return searcher.tree(num_leaves=num_leaves, + num_leaves_to_search=num_leaves_to_search, + training_sample_size=partioning_trainsize). \ + score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + + +def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): + return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( + reorder_k).build() + +def load_datapool(dpath): + + + def load_single_file(saved_embeddings): + compressed = np.load(saved_embeddings) + database = {key: compressed[key] for key in compressed.files} + return database + + def load_multi_files(data_archive): + database = {key: [] for key in data_archive[0].files} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + database[key].append(d[key]) + + return database + + print(f'Load saved patch embedding from "{dpath}"') + file_content = glob.glob(os.path.join(dpath, '*.npz')) + + if len(file_content) == 1: + data_pool = load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + else: + raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') + return data_pool + + +def train_searcher(opt, + metric='dot_product', + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None,): + + data_pool = load_datapool(opt.database) + k = opt.knn + + if not reorder_k: + reorder_k = 2 * k + + # normalize + # embeddings = + searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) + pool_size = data_pool['embedding'].shape[0] + + print(*(['#'] * 100)) + print('Initializing scaNN searcher with the following values:') + print(f'k: {k}') + print(f'metric: {metric}') + print(f'reorder_k: {reorder_k}') + print(f'anisotropic_quantization_threshold: {aiq_thld}') + print(f'dims_per_block: {dims_per_block}') + print(*(['#'] * 100)) + print('Start training searcher....') + print(f'N samples in pool is {pool_size}') + + # this reflects the recommended design choices proposed at + # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md + if pool_size < 2e4: + print('Using brute force search.') + searcher = search_bruteforce(searcher) + elif 2e4 <= pool_size and pool_size < 1e5: + print('Using asymmetric hashing search and reordering.') + searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + else: + print('Using using partioning, asymmetric hashing search and reordering.') + + if not partioning_trainsize: + partioning_trainsize = data_pool['embedding'].shape[0] // 10 + if not num_leaves: + num_leaves = int(np.sqrt(pool_size)) + + if not num_leaves_to_search: + num_leaves_to_search = max(num_leaves // 20, 1) + + print('Partitioning params:') + print(f'num_leaves: {num_leaves}') + print(f'num_leaves_to_search: {num_leaves_to_search}') + # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search) + + print('Finish training searcher') + searcher_savedir = opt.target_path + os.makedirs(searcher_savedir, exist_ok=True) + searcher.serialize(searcher_savedir) + print(f'Saved trained searcher under "{searcher_savedir}"') + +if __name__ == '__main__': + sys.path.append(os.getcwd()) + parser = argparse.ArgumentParser() + parser.add_argument('--database', + '-d', + default='data/rdm/retrieval_databases/openimages', + type=str, + help='path to folder containing the clip feature of the database') + parser.add_argument('--target_path', + '-t', + default='data/rdm/searchers/openimages', + type=str, + help='path to the target folder where the searcher shall be stored.') + parser.add_argument('--knn', + '-k', + default=20, + type=int, + help='number of nearest neighbors, for which the searcher shall be optimized') + + opt, _ = parser.parse_known_args() + + train_searcher(opt,) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/scripts/txt2img.py b/examples/tutorial/stable_diffusion/scripts/txt2img.py new file mode 100644 index 000000000..59c16a1db --- /dev/null +++ b/examples/tutorial/stable_diffusion/scripts/txt2img.py @@ -0,0 +1,344 @@ +import argparse, os, sys, glob +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from imwatermark import WatermarkEncoder +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor + + +# load safety model +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) +safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def load_replacement(x): + try: + hwc = x.shape + y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + +def check_safety(x_image): + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + assert x_checked_image.shape[0] == len(has_nsfw_concept) + for i in range(len(has_nsfw_concept)): + if has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--laion400m", + action='store_true', + help="uses the LAION400M model", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=2, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a. batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + opt = parser.parse_args() + + if opt.laion400m: + print("Falling back to LAION 400M model...") + opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" + opt.ckpt = "models/ldm/text2img-large/model.ckpt" + opt.outdir = "outputs/txt2img-samples-laion400m" + + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "StableDiffusionV1" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision=="autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + if not opt.skip_save: + for x_sample in x_checked_image_torch: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + + if not opt.skip_grid: + all_samples.append(x_checked_image_torch) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + toc = time.time() + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/stable_diffusion/setup.py b/examples/tutorial/stable_diffusion/setup.py new file mode 100644 index 000000000..a24d54167 --- /dev/null +++ b/examples/tutorial/stable_diffusion/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name='latent-diffusion', + version='0.0.1', + description='', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'tqdm', + ], +) \ No newline at end of file diff --git a/examples/tutorial/stable_diffusion/train.sh b/examples/tutorial/stable_diffusion/train.sh new file mode 100644 index 000000000..63abcadbf --- /dev/null +++ b/examples/tutorial/stable_diffusion/train.sh @@ -0,0 +1,4 @@ +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 + +python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml diff --git a/inference b/inference index 98a12bc21..6dadc2a4f 160000 --- a/inference +++ b/inference @@ -1 +1 @@ -Subproject commit 98a12bc2107b206017c4793380538f9cdec5a5e1 +Subproject commit 6dadc2a4f293f4314280d6250463d986536e46ea diff --git a/op_builder/README.md b/op_builder/README.md new file mode 100644 index 000000000..057da1038 --- /dev/null +++ b/op_builder/README.md @@ -0,0 +1,31 @@ +# Build PyTorch Extensions + +## Overview + +Building PyTorch extensions can be a difficult task for users not from the system background. It is definitely frustrating if the users encounter many strange technical jargons when install Colossal-AI. Therefore, we will provide two methods of building the PyTorch extensions for the users. + +1. Build CUDA extensions when running `pip install` if `CUDA_EXT=1` +2. Build the extension during runtime + +The first method is more suitable for users who are familiar with CUDA environment configurations. The second method is for those who are not as they only need to build the kernel which is required by their program. + +These two methods have different advantages and disadvantages. +Method 1 is good because it allows the user to build all kernels during installation and directly import the kernel. They don't need to care about kernel building when running their program. However, installation may fail if they don't know how to configure their environments and this leads to much frustration. +Method 2 is good because it allows the user to only build the kernel they actually need, such that there is a lower probability that they encounter environment issue. However, it may slow down their program due to the first build and subsequence load. + +## PyTorch Extensions in Colossal-AI + +As mentioned in the section above, our aim is to make these two methods coherently supported in Colossal-AI, meaning that for a kernel should be either built in `setup.py` or during runtime. +There are mainly two functions used to build extensions. + +1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`. +2. `torch.utils.cpp_extension.load`: used to build and load extension during runtime + +Please note that the extension build by `CUDAExtension` cannot be loaded by the `load` function and `load` will run its own build again (correct me if I am wrong). + +We have implemented the following conventions: + +1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C` +2. All runtime-built kernels will be found in the default torch extension path, i.e. ~/.cache/colossalai/torch_extensions. (If we put the built kernels in the installed site-package directory, this will make pip uninstall incomplete) + +When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered. diff --git a/op_builder/__init__.py b/op_builder/__init__.py new file mode 100644 index 000000000..5ae7223b8 --- /dev/null +++ b/op_builder/__init__.py @@ -0,0 +1,23 @@ +from .cpu_adam import CPUAdamBuilder +from .fused_optim import FusedOptimBuilder +from .layernorm import LayerNormBuilder +from .moe import MOEBuilder +from .multi_head_attn import MultiHeadAttnBuilder +from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder +from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder + +ALL_OPS = { + 'cpu_adam': CPUAdamBuilder, + 'fused_optim': FusedOptimBuilder, + 'moe': MOEBuilder, + 'multi_head_attn': MultiHeadAttnBuilder, + 'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder, + 'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder, + 'layernorm': LayerNormBuilder, +} + +__all__ = [ + 'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder', + 'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder', + 'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder' +] diff --git a/op_builder/builder.py b/op_builder/builder.py new file mode 100644 index 000000000..dc9ea8e11 --- /dev/null +++ b/op_builder/builder.py @@ -0,0 +1,165 @@ +import importlib +import os +import time +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List + + +class Builder(ABC): + """ + Builder is the base class to build extensions for PyTorch. + + Args: + name (str): the name of the kernel to be built + prebuilt_import_path (str): the path where the extension is installed during pip install + """ + + def __init__(self, name: str, prebuilt_import_path: str): + self.name = name + self.prebuilt_import_path = prebuilt_import_path + self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + + assert prebuilt_import_path.startswith('colossalai._C'), \ + f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}' + + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. + """ + op_builder_module_path = Path(__file__).parent + + # if we install from source + # the current file path will be op_builder/builder.py + # if we install via pip install colossalai + # the current file path will be colossalai/kernel/op_builder/builder.py + # this is because that the op_builder inside colossalai is a symlink + # this symlink will be replaced with actual files if we install via pypi + # thus we cannot tell the colossalai root directory by checking whether the op_builder + # is a symlink, we can only tell whether it is inside or outside colossalai + if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'): + root_path = op_builder_module_path.parent.parent + else: + root_path = op_builder_module_path.parent.joinpath('colossalai') + + code_abs_path = root_path.joinpath(code_path) + return str(code_abs_path) + + def get_cuda_home_include(self): + """ + return include path inside the cuda home. + """ + from torch.utils.cpp_extension import CUDA_HOME + if CUDA_HOME is None: + raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") + cuda_include = os.path.join(CUDA_HOME, "include") + return cuda_include + + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path) + + # functions must be overrided begin + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ + raise NotImplementedError + + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of inlcude files for extensions. + """ + pass + + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + pass + + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + pass + + # functions must be overrided over + def strip_empty_entries(self, args): + ''' + Drop any empty strings from the list of compile and link flags + ''' + return [x for x in args if len(x) > 0] + + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + + def load(self, verbose=True): + """ + load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel. + If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the + kernel is built during pip install, it can be accessed through `colossalai._C`. + + Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. + + Args: + verbose (bool, optional): show detailed info. Defaults to True. + """ + from torch.utils.cpp_extension import load + start_build = time.time() + + try: + op_module = self.import_op() + if verbose: + print(f"OP {self.prebuilt_import_path} already exists, skip building.") + except ImportError: + # construct the build directory + import torch + torch_version_major = torch.__version__.split('.')[0] + torch_version_minor = torch.__version__.split('.')[1] + torch_cuda_version = torch.version.cuda + home_directory = os.path.expanduser('~') + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" + build_directory = os.path.join(home_directory, extension_directory) + Path(build_directory).mkdir(parents=True, exist_ok=True) + + if verbose: + print("=========================================================================================") + print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now") + print("=========================================================================================") + + # load the kernel + op_module = load(name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=build_directory, + verbose=verbose) + + build_duration = time.time() - start_build + if verbose: + print(f"Time to load {self.name} op: {build_duration} seconds") + + return op_module + + def builder(self) -> 'CUDAExtension': + """ + get a CUDAExtension instance used for setup.py + """ + from torch.utils.cpp_extension import CUDAExtension + + return CUDAExtension(name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + 'cxx': self.strip_empty_entries(self.cxx_flags()), + 'nvcc': self.strip_empty_entries(self.nvcc_flags()) + }) diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py new file mode 100644 index 000000000..500e2cc0e --- /dev/null +++ b/op_builder/cpu_adam.py @@ -0,0 +1,38 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads + + +class CPUAdamBuilder(Builder): + NAME = "cpu_adam" + PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" + + def __init__(self): + super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) + self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path('cpu_adam.cpp'), + ] + return ret + + def include_dirs(self): + return [ + self.csrc_abs_path("includes"), + self.get_cuda_home_include() + ] + + def cxx_flags(self): + extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] + return ['-O3'] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py new file mode 100644 index 000000000..31ddfced1 --- /dev/null +++ b/op_builder/fused_optim.py @@ -0,0 +1,34 @@ +import os + +from .builder import Builder +from .utils import get_cuda_cc_flag + + +class FusedOptimBuilder(Builder): + NAME = "fused_optim" + PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim" + + def __init__(self): + super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in [ + 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', + 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' + ] + ] + return ret + + def include_dirs(self): + ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + return ['-O3'] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ['-lineinfo'] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ['-O3', '--use_fast_math'] + extra_cuda_flags diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py new file mode 100644 index 000000000..61d941741 --- /dev/null +++ b/op_builder/layernorm.py @@ -0,0 +1,29 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class LayerNormBuilder(Builder): + NAME = "layernorm" + PREBUILT_IMPORT_PATH = "colossalai._C.layernorm" + + def __init__(self): + super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu']] + return ret + + def include_dirs(self): + ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ['-maxrregcount=50'] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + self.version_dependent_macros + return append_nvcc_threads(ret) diff --git a/op_builder/moe.py b/op_builder/moe.py new file mode 100644 index 000000000..eeb7d8e39 --- /dev/null +++ b/op_builder/moe.py @@ -0,0 +1,36 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class MOEBuilder(Builder): + + NAME = "moe" + PREBUILT_IMPORT_PATH = "colossalai._C.moe" + + def __init__(self): + super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + ret = [ + self.csrc_abs_path("kernels/include"), + self.get_cuda_home_include() + ] + return ret + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py new file mode 100644 index 000000000..f9103fe94 --- /dev/null +++ b/op_builder/multi_head_attn.py @@ -0,0 +1,41 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class MultiHeadAttnBuilder(Builder): + + NAME = "multihead_attention" + PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" + + def __init__(self): + super().__init__(name=MultiHeadAttnBuilder.NAME, + prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) + + + def include_dirs(self): + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in [ + 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', + 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu', + 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' + ] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/scaled_masked_softmax.py b/op_builder/scaled_masked_softmax.py new file mode 100644 index 000000000..11cfda39a --- /dev/null +++ b/op_builder/scaled_masked_softmax.py @@ -0,0 +1,37 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads + + +class ScaledMaskedSoftmaxBuilder(Builder): + NAME = "scaled_masked_softmax" + PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" + + def __init__(self): + super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in + ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'] + ] + return ret + + def include_dirs(self): + return [ + self.csrc_abs_path("kernels/include"), + self.get_cuda_home_include() + ] + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py new file mode 100644 index 000000000..d0d2433aa --- /dev/null +++ b/op_builder/scaled_upper_triangle_masked_softmax.py @@ -0,0 +1,37 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): + NAME = "scaled_upper_triangle_masked_softmax" + PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" + + def __init__(self): + super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + return [ + self.csrc_abs_path("kernels/include"), + self.get_cuda_home_include() + ] + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/utils.py b/op_builder/utils.py new file mode 100644 index 000000000..b6bada99e --- /dev/null +++ b/op_builder/utils.py @@ -0,0 +1,42 @@ +import re +import subprocess +from typing import List + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + +def get_cuda_cc_flag() -> List: + """get_cuda_cc_flag + + cc flag for your GPU arch + """ + + # only import torch when needed + # this is to avoid importing torch when building on a machine without torch pre-installed + # one case is to build wheel for pypi release + import torch + + cc_flag = [] + for arch in torch.cuda.get_arch_list(): + res = re.search(r'sm_(\d+)', arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 60: + cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + + return cc_flag + +def append_nvcc_threads(nvcc_extra_args): + from torch.utils.cpp_extension import CUDA_HOME + _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 7fd805c14..9ef0a682b 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,9 +1,13 @@ -diffusers +fbgemm-gpu==0.2.0 pytest +pytest-cov torchvision transformers timm titans torchaudio -torchrec +torchrec==0.2.0 contexttimer +einops +triton==2.0.0.dev20221011 +git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 528bc6f25..cc99257a9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,4 +1,3 @@ -torch>=1.8 numpy tqdm psutil @@ -7,4 +6,5 @@ pre-commit rich click fabric -contexttimer \ No newline at end of file +contexttimer +ninja diff --git a/setup.py b/setup.py index 8341a97b7..5128b80e8 100644 --- a/setup.py +++ b/setup.py @@ -1,26 +1,44 @@ import os -import subprocess import re -from setuptools import find_packages, setup, Extension +from datetime import datetime + +from setuptools import find_packages, setup + +from op_builder.utils import get_cuda_bare_metal_version + +try: + import torch + from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 10): + raise RuntimeError("Colossal-AI requires Pytorch 1.10 or newer.\n" + "The latest stable release can be obtained from https://pytorch.org/") + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + CUDA_HOME = None # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) -build_cuda_ext = True +build_cuda_ext = False ext_modules = [] +is_nightly = int(os.environ.get('NIGHTLY', '0')) == 1 -if int(os.environ.get('NO_CUDA_EXT', '0')) == 1: - build_cuda_ext = False +if int(os.environ.get('CUDA_EXT', '0')) == 1: + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" + ) + if not CUDA_HOME: + raise RuntimeError( + "CUDA_HOME is not found while CUDA_EXT=1. You need to export CUDA_HOME environment vairable or install CUDA Toolkit first in order to build CUDA extensions" + ) -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor + build_cuda_ext = True def check_cuda_torch_binary_vs_bare_metal(cuda_dir): @@ -92,29 +110,24 @@ def fetch_readme(): def get_version(): - with open('version.txt') as f: + setup_file_path = os.path.abspath(__file__) + project_path = os.path.dirname(setup_file_path) + version_txt_path = os.path.join(project_path, 'version.txt') + version_py_path = os.path.join(project_path, 'colossalai/version.py') + + with open(version_txt_path) as f: version = f.read().strip() if build_cuda_ext: torch_version = '.'.join(torch.__version__.split('.')[:2]) cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME)[1:]) version += f'+torch{torch_version}cu{cuda_version}' - return version + # write version into version.py + with open(version_py_path, 'w') as f: + f.write(f"__version__ = '{version}'\n") -if build_cuda_ext: - try: - import torch - from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CUDAExtension) - print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) + return version - if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 8): - raise RuntimeError("Colossal-AI requires Pytorch 1.8 or newer.\n" - "The latest stable release can be obtained from https://pytorch.org/") - except ImportError: - print('torch is not found. CUDA extension will not be installed') - build_cuda_ext = False if build_cuda_ext: build_cuda_ext = check_cuda_availability(CUDA_HOME) and check_cuda_torch_binary_vs_bare_metal(CUDA_HOME) @@ -125,111 +138,63 @@ if build_cuda_ext: # and # https://github.com/NVIDIA/apex/issues/456 # https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac - version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] - def cuda_ext_helper(name, sources, extra_cuda_flags, extra_cxx_flags=[]): - return CUDAExtension( - name=name, - sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources], - include_dirs=[os.path.join(this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')], - extra_compile_args={ - 'cxx': ['-O3'] + version_dependent_macros + extra_cxx_flags, - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags) - }) + from op_builder import ALL_OPS + for name, builder_cls in ALL_OPS.items(): + print(f'===== Building Extension {name} =====') + ext_modules.append(builder_cls().builder()) - cc_flag = [] - for arch in torch.cuda.get_arch_list(): - res = re.search(r'sm_(\d+)', arch) - if res: - arch_cap = res[1] - if int(arch_cap) >= 60: - cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) +if is_nightly: + # use date as the nightly version + version = datetime.today().strftime('%Y.%m.%d') + package_name = 'colossalai-nightly' +else: + version = get_version() + package_name = 'colossalai' - extra_cuda_flags = ['-lineinfo'] - - ext_modules.append( - cuda_ext_helper('colossal_C', [ - 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', - 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' - ], extra_cuda_flags + cc_flag)) - - extra_cuda_flags = [ - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', - '--expt-extended-lambda' - ] - - ext_modules.append( - cuda_ext_helper('colossal_scaled_upper_triang_masked_softmax', - ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'], - extra_cuda_flags + cc_flag)) - - ext_modules.append( - cuda_ext_helper('colossal_scaled_masked_softmax', - ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag)) - - ext_modules.append( - cuda_ext_helper('colossal_moe_cuda', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag)) - - extra_cuda_flags = ['-maxrregcount=50'] - - ext_modules.append( - cuda_ext_helper('colossal_layer_norm_cuda', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'], - extra_cuda_flags + cc_flag)) - - extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', - '-DTHRUST_IGNORE_CUB_VERSION_CHECK' - ] - - ext_modules.append( - cuda_ext_helper('colossal_multihead_attention', [ - 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', - 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu', - 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' - ], extra_cuda_flags + cc_flag)) - - extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] - ext_modules.append(cuda_ext_helper('cpu_adam', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags)) - -setup( - name='colossalai', - version=get_version(), - packages=find_packages(exclude=( - 'benchmark', - 'docker', - 'tests', - 'docs', - 'examples', - 'tests', - 'scripts', - 'requirements', - '*.egg-info', - )), - description='An integrated large-scale model training system with efficient parallelization techniques', - long_description=fetch_readme(), - long_description_content_type='text/markdown', - license='Apache Software License 2.0', - url='https://www.colossalai.org', - project_urls={ - 'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions', - 'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues', - 'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples', - 'Documentation': 'http://colossalai.readthedocs.io', - 'Github': 'https://github.com/hpcaitech/ColossalAI', - }, - ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension} if ext_modules else {}, - install_requires=fetch_requirements('requirements/requirements.txt'), - entry_points=''' +setup(name=package_name, + version=version, + packages=find_packages(exclude=( + 'benchmark', + 'docker', + 'tests', + 'docs', + 'examples', + 'tests', + 'scripts', + 'requirements', + '*.egg-info', + )), + description='An integrated large-scale model training system with efficient parallelization techniques', + long_description=fetch_readme(), + long_description_content_type='text/markdown', + license='Apache Software License 2.0', + url='https://www.colossalai.org', + project_urls={ + 'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions', + 'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues', + 'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples', + 'Documentation': 'http://colossalai.readthedocs.io', + 'Github': 'https://github.com/hpcaitech/ColossalAI', + }, + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension} if ext_modules else {}, + install_requires=fetch_requirements('requirements/requirements.txt'), + entry_points=''' [console_scripts] colossalai=colossalai.cli:cli ''', - python_requires='>=3.6', - classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Environment :: GPU :: NVIDIA CUDA', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: System :: Distributed Computing', - ], -) + python_requires='>=3.6', + classifiers=[ + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: Apache Software License', + 'Environment :: GPU :: NVIDIA CUDA', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: System :: Distributed Computing', + ], + package_data={ + 'colossalai': [ + '_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', + 'kernel/cuda_native/csrc/kernels/include/*' + ] + }) diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py index f87d35ff9..106f4e61c 100644 --- a/tests/components_to_test/__init__.py +++ b/tests/components_to_test/__init__.py @@ -1 +1,19 @@ -from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net, gpt +from . import ( + beit, + bert, + gpt2, + hanging_param_model, + inline_op_model, + nested_model, + repeated_computed_layers, + resnet, + simple_net, +) +from .utils import run_fwd_bwd + +from . import albert # isort:skip + +__all__ = [ + 'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet', + 'simple_net', 'run_fwd_bwd', 'albert', 'beit' +] diff --git a/tests/components_to_test/albert.py b/tests/components_to_test/albert.py new file mode 100644 index 000000000..d5b6bc89a --- /dev/null +++ b/tests/components_to_test/albert.py @@ -0,0 +1,59 @@ +import torch +import transformers +from packaging import version +from transformers import AlbertConfig, AlbertForSequenceClassification + +from .bert import get_bert_data_loader +from .registry import non_distributed_component_funcs + + +@non_distributed_component_funcs.register(name='albert') +def get_training_components(): + hidden_dim = 8 + num_head = 4 + sequence_length = 12 + num_layer = 2 + vocab_size = 32 + + def bert_model_builder(checkpoint: bool = False): + config = AlbertConfig(vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0., + attention_probs_dropout_prob=0.) + print('building AlbertForSequenceClassification model') + + # adapting huggingface BertForSequenceClassification for single unitest calling interface + class ModelAaptor(AlbertForSequenceClassification): + + def forward(self, input_ids, labels): + """ + inputs: data, label + outputs: loss + """ + return super().forward(input_ids=input_ids, labels=labels)[0] + + model = ModelAaptor(config) + # if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): + # model.gradient_checkpointing_enable() + + return model + + is_distrbuted = torch.distributed.is_initialized() + trainloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + testloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distrbuted=is_distrbuted) + + criterion = None + return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/beit.py b/tests/components_to_test/beit.py new file mode 100644 index 000000000..1252071f4 --- /dev/null +++ b/tests/components_to_test/beit.py @@ -0,0 +1,42 @@ +import torch +from timm.models.beit import Beit + +from colossalai.utils.cuda import get_current_device + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class DummyDataLoader(DummyDataGenerator): + img_size = 64 + num_channel = 3 + num_class = 10 + batch_size = 4 + + def generate(self): + data = torch.randn((DummyDataLoader.batch_size, DummyDataLoader.num_channel, DummyDataLoader.img_size, + DummyDataLoader.img_size), + device=get_current_device()) + label = torch.randint(low=0, + high=DummyDataLoader.num_class, + size=(DummyDataLoader.batch_size,), + device=get_current_device()) + return data, label + + +@non_distributed_component_funcs.register(name='beit') +def get_training_components(): + + def model_buider(checkpoint=False): + model = Beit(img_size=DummyDataLoader.img_size, + num_classes=DummyDataLoader.num_class, + embed_dim=32, + depth=2, + num_heads=4) + return model + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + return model_buider, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py index e8d202b69..c1faa6f9d 100644 --- a/tests/components_to_test/bert.py +++ b/tests/components_to_test/bert.py @@ -8,6 +8,7 @@ from .registry import non_distributed_component_funcs def get_bert_data_loader( + n_class, batch_size, total_samples, sequence_length, @@ -16,7 +17,7 @@ def get_bert_data_loader( ): train_data = torch.randint( low=0, - high=1000, + high=n_class, size=(total_samples, sequence_length), device=device, dtype=torch.long, @@ -37,9 +38,9 @@ def get_training_components(): num_head = 4 sequence_length = 12 num_layer = 2 - vocab_size = 30524 + vocab_size = 32 - def bert_model_builder(checkpoint): + def bert_model_builder(checkpoint: bool = False): config = BertConfig(vocab_size=vocab_size, gradient_checkpointing=checkpoint, hidden_size=hidden_dim, @@ -67,14 +68,17 @@ def get_training_components(): return model - trainloader = get_bert_data_loader(batch_size=2, + is_distrbuted = torch.distributed.is_initialized() + trainloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, total_samples=10000, sequence_length=sequence_length, - is_distrbuted=True) - testloader = get_bert_data_loader(batch_size=2, + is_distrbuted=is_distrbuted) + testloader = get_bert_data_loader(n_class=vocab_size, + batch_size=2, total_samples=10000, sequence_length=sequence_length, - is_distrbuted=True) + is_distrbuted=is_distrbuted) criterion = None return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/gpt.py b/tests/components_to_test/gpt2.py similarity index 96% rename from tests/components_to_test/gpt.py rename to tests/components_to_test/gpt2.py index 3123211ad..fe25b4923 100644 --- a/tests/components_to_test/gpt.py +++ b/tests/components_to_test/gpt2.py @@ -1,10 +1,12 @@ import torch import torch.nn as nn -from .registry import non_distributed_component_funcs from transformers import GPT2Config, GPT2LMHeadModel -from .utils.dummy_data_generator import DummyDataGenerator + from colossalai.utils.cuda import get_current_device +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + class DummyDataLoader(DummyDataGenerator): vocab_size = 128 @@ -15,8 +17,7 @@ class DummyDataLoader(DummyDataGenerator): input_ids = torch.randint(0, DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len), device=get_current_device()) - attention_mask = torch.ones_like(input_ids) - return input_ids, attention_mask + return input_ids, input_ids class GPTLMModel(nn.Module): @@ -43,8 +44,9 @@ class GPTLMModel(nn.Module): if checkpoint: self.model.gradient_checkpointing_enable() - def forward(self, input_ids, attention_mask): + def forward(self, input_ids): # Only return lm_logits + attention_mask = torch.ones_like(input_ids) return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] diff --git a/tests/components_to_test/no_leaf_module.py b/tests/components_to_test/hanging_param_model.py similarity index 76% rename from tests/components_to_test/no_leaf_module.py rename to tests/components_to_test/hanging_param_model.py index 28a212f96..329a08ea2 100644 --- a/tests/components_to_test/no_leaf_module.py +++ b/tests/components_to_test/hanging_param_model.py @@ -1,46 +1,49 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from colossalai.nn import CheckpointModule -from .utils.dummy_data_generator import DummyDataGenerator -from .registry import non_distributed_component_funcs - - -class NoLeafModule(CheckpointModule): - """ - In this no-leaf module, it has subordinate nn.modules and a nn.Parameter. - """ - - def __init__(self, checkpoint=False) -> None: - super().__init__(checkpoint=checkpoint) - self.proj1 = nn.Linear(4, 8) - self.weight = nn.Parameter(torch.randn(8, 8)) - self.proj2 = nn.Linear(8, 4) - - def forward(self, x): - x = self.proj1(x) - x = F.linear(x, self.weight) - x = self.proj2(x) - return x - - -class DummyDataLoader(DummyDataGenerator): - - def generate(self): - data = torch.rand(16, 4) - label = torch.randint(low=0, high=2, size=(16,)) - return data, label - - -@non_distributed_component_funcs.register(name='no_leaf_module') -def get_training_components(): - - def model_builder(checkpoint=True): - return NoLeafModule(checkpoint) - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - from colossalai.nn.optimizer import HybridAdam - return model_builder, trainloader, testloader, HybridAdam, criterion +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class HangingParamModule(CheckpointModule): + """ + Hanging Parameter: a parameter dose not belong to a leaf Module. + It has subordinate nn.modules and a nn.Parameter. + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.weight = nn.Parameter(torch.randn(8, 8)) + self.proj2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.proj1(x) + x = F.linear(x, self.weight) + x = self.proj2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 4) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='hanging_param_model') +def get_training_components(): + + def model_builder(checkpoint=False): + return HangingParamModule(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py new file mode 100644 index 000000000..f061d48f9 --- /dev/null +++ b/tests/components_to_test/inline_op_model.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.nn import CheckpointModule + +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + +class InlineOpModule(CheckpointModule): + """ + a module with inline Ops + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.proj2 = nn.Linear(8, 8) + + def forward(self, x): + + x = self.proj1(x) + # inline add_ + x.add_(10) + x = self.proj2(x) + # inline relu_ + x = torch.relu_(x) + x = self.proj2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 4) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='inline_op_model') +def get_training_components(): + + def model_builder(checkpoint=False): + return InlineOpModule(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/nested_model.py b/tests/components_to_test/nested_model.py index 26bfb8ecc..339084639 100644 --- a/tests/components_to_test/nested_model.py +++ b/tests/components_to_test/nested_model.py @@ -1,9 +1,11 @@ import torch import torch.nn as nn import torch.nn.functional as F + from colossalai.nn import CheckpointModule -from .utils import DummyDataGenerator + from .registry import non_distributed_component_funcs +from .utils import DummyDataGenerator class SubNet(nn.Module): @@ -43,7 +45,7 @@ class DummyDataLoader(DummyDataGenerator): @non_distributed_component_funcs.register(name='nested_model') def get_training_components(): - def model_builder(checkpoint=True): + def model_builder(checkpoint=False): return NestedNet(checkpoint) trainloader = DummyDataLoader() diff --git a/tests/components_to_test/repeated_computed_layer.py b/tests/components_to_test/repeated_computed_layers.py similarity index 96% rename from tests/components_to_test/repeated_computed_layer.py rename to tests/components_to_test/repeated_computed_layers.py index f70910191..b3f84bd0e 100644 --- a/tests/components_to_test/repeated_computed_layer.py +++ b/tests/components_to_test/repeated_computed_layers.py @@ -2,9 +2,11 @@ import torch import torch.nn as nn + from colossalai.nn import CheckpointModule -from .utils.dummy_data_generator import DummyDataGenerator + from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator class NetWithRepeatedlyComputedLayers(CheckpointModule): @@ -37,7 +39,7 @@ class DummyDataLoader(DummyDataGenerator): @non_distributed_component_funcs.register(name='repeated_computed_layers') def get_training_components(): - def model_builder(checkpoint=True): + def model_builder(checkpoint=False): return NetWithRepeatedlyComputedLayers(checkpoint) trainloader = DummyDataLoader() diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py index fd4988d9e..cd9d7ebc0 100644 --- a/tests/components_to_test/simple_net.py +++ b/tests/components_to_test/simple_net.py @@ -1,10 +1,13 @@ import torch import torch.nn as nn + from colossalai.nn import CheckpointModule -from .utils.dummy_data_generator import DummyDataGenerator -from .registry import non_distributed_component_funcs from colossalai.utils.cuda import get_current_device +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + + class SimpleNet(CheckpointModule): """ In this no-leaf module, it has subordinate nn.modules and a nn.Parameter. @@ -29,7 +32,6 @@ class SimpleNet(CheckpointModule): return x - class DummyDataLoader(DummyDataGenerator): def generate(self): @@ -41,7 +43,7 @@ class DummyDataLoader(DummyDataGenerator): @non_distributed_component_funcs.register(name='simple_net') def get_training_components(): - def model_builder(checkpoint=True): + def model_builder(checkpoint=False): return SimpleNet(checkpoint) trainloader = DummyDataLoader() diff --git a/tests/components_to_test/utils/__init__.py b/tests/components_to_test/utils/__init__.py index fc6321214..f223f7d32 100644 --- a/tests/components_to_test/utils/__init__.py +++ b/tests/components_to_test/utils/__init__.py @@ -1 +1,2 @@ -from .dummy_data_generator import DummyDataGenerator +from .dummy_data_generator import DummyDataGenerator +from .executor import run_fwd_bwd diff --git a/tests/components_to_test/utils/executor.py b/tests/components_to_test/utils/executor.py new file mode 100644 index 000000000..e77152561 --- /dev/null +++ b/tests/components_to_test/utils/executor.py @@ -0,0 +1,29 @@ +import torch + + +def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: + """run_fwd_bwd + run fwd and bwd for the model + + Args: + model (torch.nn.Module): a PyTorch model + data (torch.Tensor): input data + label (torch.Tensor): label + criterion (Optional[Callable]): a function of criterion + + Returns: + torch.Tensor: loss of fwd + """ + if criterion: + y = model(data) + y = y.float() + loss = criterion(y, label) + else: + loss = model(data, label) + + loss = loss.float() + if optimizer: + optimizer.backward(loss) + else: + loss.backward() + return loss diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py index 95c5686ae..7f6f0c86a 100644 --- a/tests/test_amp/test_naive_fp16.py +++ b/tests/test_amp/test_naive_fp16.py @@ -1,18 +1,16 @@ +import copy +from functools import partial + +import pytest import torch -import colossalai import torch.multiprocessing as mp -from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp -from tests.components_to_test.registry import non_distributed_component_funcs + +import colossalai +from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp - from tests.components_to_test.registry import non_distributed_component_funcs -import copy -import pytest -from functools import partial - def check_equal(a, b): """ @@ -23,7 +21,7 @@ def check_equal(a, b): def run_naive_amp(): """ - In this test, we compare the naive fp16 optimizer implemented in colossalai + In this test, we compare the naive fp16 optimizer implemented in colossalai and fp32 torch optimizer """ @@ -41,11 +39,12 @@ def run_naive_amp(): apex_amp_model = copy.deepcopy(naive_amp_model) # create optimizer - naive_amp_optimizer = optim_class(naive_amp_model.parameters(), lr=1e-3) - apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3) + # we use SGD here, since the correctness of gradient clipping can't be tested with Adam + naive_amp_optimizer = torch.optim.SGD(naive_amp_model.parameters(), lr=1e-3) + apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3) # inject naive and apex amp - naive_amp_config = dict(initial_scale=128) + naive_amp_config = dict(initial_scale=128, clip_grad_norm=1.0) naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer, naive_amp_config) apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False) @@ -62,13 +61,17 @@ def run_naive_amp(): assert_close_loose(naive_amp_output, apex_amp_output) # backward - naive_amp_optimizer.backward(naive_amp_output.mean()) - apex_amp_optimizer.backward(apex_amp_output.mean()) + # use sum() to get big gradient + naive_amp_optimizer.backward(naive_amp_output.sum()) + apex_amp_optimizer.backward(apex_amp_output.sum()) # check grad for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()): assert_close_loose(naive_amp_param.grad, apex_amp_param.grad) + # clip gradient + apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0) + # step naive_amp_optimizer.step() apex_amp_optimizer.step() diff --git a/tests/test_amp/test_torch_fp16.py b/tests/test_amp/test_torch_fp16.py new file mode 100644 index 000000000..e65dd8cde --- /dev/null +++ b/tests/test_amp/test_torch_fp16.py @@ -0,0 +1,97 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp +from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.components_to_test.registry import non_distributed_component_funcs + + +def run_torch_amp(): + """ + In this test, we compare the torch amp and apex amp implemented in colossalai + """ + + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + # create layer + test_models = ['resnet18', 'simple_net'] + for test_name in test_models: + get_component_func = non_distributed_component_funcs.get_callable(test_name) + model_builder, train_dataloader, _, optim_class, _ = get_component_func() + + # create model + torch_amp_model = model_builder(checkpoint=True).cuda() + apex_amp_model = copy.deepcopy(torch_amp_model) + + # create optimizer + # we use SGD here, since the correctness of gradient clipping can't be tested with Adam + torch_amp_optimizer = torch.optim.SGD(torch_amp_model.parameters(), lr=1e-3) + apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3) + + # inject torch and apex amp + torch_amp_config = dict(init_scale=128, enabled=True) + torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp(torch_amp_model, + torch_amp_optimizer, + amp_config=torch_amp_config) + apex_amp_config = dict(opt_level='O1', loss_scale=128) + apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) + + # create data + data_iter = iter(train_dataloader) + data, label = next(data_iter) + data = data.cuda() + + # forward pass + torch_amp_output = torch_amp_model(data) + apex_amp_output = apex_amp_model(data) + assert_close_loose(torch_amp_output, apex_amp_output) + + for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(torch_amp_param, apex_amp_param) + + # backward + # use sum() to get big gradient + torch_amp_optimizer.backward(torch_amp_output.sum()) + apex_amp_optimizer.backward(apex_amp_output.sum()) + + # check grad + # In apex amp, grad is not scaled before backward, but torch amp does + for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config['loss_scale']) + + # clip gradient + apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0) + torch_amp_optimizer.clip_grad_norm(model=torch_amp_model, max_norm=1.0) + + # step + torch_amp_optimizer.step() + apex_amp_optimizer.step() + + # check updated param and grad + for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): + assert_close_loose(torch_amp_param.grad, apex_amp_param.grad) + assert_close_loose(torch_amp_param, apex_amp_param) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + run_torch_amp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_torch_amp(): + world_size = 1 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_torch_amp() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py new file mode 100644 index 000000000..e666cb175 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -0,0 +1,172 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.linear(x) + x = x * 2 + + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias) + + def forward(self, x): + x = self.conv(x) + x = x * 2 + + return x + + +def check_linear_module(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModel(4, 8).cuda() + input = torch.rand(4, 4).cuda() + output_compare = model(input) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %linear_weight : [#users=1] = get_attr[target=linear.weight] + # %linear_bias : [#users=1] = get_attr[target=linear.bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')}) + # def forward(self, x : torch.Tensor): + # linear_weight = self.linear.weight + # linear_bias = self.linear.bias + # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None + # add = linear + linear_bias; linear = linear_bias = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + gm.recompile() + node_list = list(graph.nodes) + + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + linear_node = node_list[3] + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + + gm = runtime_apply_pass(gm) + gm.recompile() + output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + assert_close(output, output_compare) + + +def check_conv_module(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvModel(3, 6, 2).cuda() + input = torch.rand(4, 3, 64, 64).cuda() + output_compare = model(input) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + # def forward(self, x : torch.Tensor): + # conv_weight = self.conv.weight + # conv_bias = self.conv.bias + # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None + # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None + # add = conv2d + view; conv2d = view = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + + gm.recompile() + + node_list = list(graph.nodes) + conv_node = node_list[3] + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + + gm = runtime_apply_pass(gm) + gm.recompile() + output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + assert_close(output, output_compare) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bias_addition_module(): + world_size = 4 + run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port()) + mp.spawn(run_func_linear, nprocs=world_size) + run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port()) + mp.spawn(run_func_conv, nprocs=world_size) + + +if __name__ == '__main__': + test_bias_addition_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py index 4c35e7de5..560758749 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py @@ -1,7 +1,10 @@ import torch -from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable, - recover_sharding_spec_for_broadcast_shape) +from colossalai.auto_parallel.tensor_shard.utils import ( + get_broadcast_shape, + is_broadcastable, + recover_sharding_spec_for_broadcast_shape, +) from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec @@ -51,8 +54,8 @@ def test_recover_sharding_spec_for_broadcast_shape(): 1: [1] }, entire_shape=broadcast_shape) - physical_sharding_spec_for_x1 = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec_for_x1, - broadcast_shape, x1.shape) + physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec_for_x1, broadcast_shape, x1.shape) print(physical_sharding_spec_for_x1) assert physical_sharding_spec_for_x1.entire_shape == x1.shape diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py index a244329c0..96d96a459 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py @@ -1,15 +1,16 @@ +from copy import deepcopy from pickletools import optimize -import torch -from torch.fx import GraphModule -import torch.nn as nn -import pytest -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from copy import deepcopy +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer class ConvModel(nn.Module): @@ -67,7 +68,8 @@ def test_cost_graph(): for node in graph.nodes: if node.op == 'output': continue - all_node_pairs.append((node, node.next)) + for child in node.users.keys(): + all_node_pairs.append((node, child)) for node_pair in all_node_pairs: assert node_pair in cost_graph.edge_costs @@ -75,14 +77,14 @@ def test_cost_graph(): # construct merged node pairs merged_node_pairs = [] node_list = list(graph.nodes) - - # add (x, conv) and (conv, output) into check node pairs - merged_node_pairs.append((node_list[0], node_list[2])) - merged_node_pairs.append((node_list[2], node_list[-1])) - # (conv1, output):{(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 246019.30000000002, (5, 0): 246019.30000000002, (6, 0): 123009.1, (7, 0): 123009.1, (8, 0): 123009.1, (9, 0): 123009.1, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): 246019.30000000002, (14, 0): 246019.30000000002} - # (x, conv1):{(0, 0): 65547.1, (0, 1): 65547.1, (0, 2): 65547.1, (0, 3): 65547.1, (0, 4): 131105.30000000002, (0, 5): 131105.30000000002, (0, 6): 65547.1, (0, 7): 65547.1, (0, 8): 65547.1, (0, 9): 65547.1, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 131105.30000000002, (0, 14): 131105.30000000002} + # add (conv1_weight, conv2d), (conv1_bias, view), (conv2d, add), (view, add), (add, output), (x, conv2d) into check node pairs + merged_node_pairs.append((node_list[0], node_list[4])) + merged_node_pairs.append((node_list[2], node_list[4])) + merged_node_pairs.append((node_list[3], node_list[5])) + merged_node_pairs.append((node_list[5], node_list[6])) + merged_node_pairs.append((node_list[4], node_list[6])) + merged_node_pairs.append((node_list[6], node_list[-1])) cost_graph.simplify_graph() - for node_pair in all_node_pairs: if node_pair in merged_node_pairs: assert node_pair in cost_graph.edge_costs diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py index 09afbdef1..9342e06a0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py @@ -1,14 +1,16 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.device.device_mesh import DeviceMesh class ConvModel(nn.Module): @@ -37,52 +39,22 @@ def test_conv_handler(): # graph(): # %x : torch.Tensor [#users=1] = placeholder[target=x] # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) - # return conv + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # return add graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - # [x, mul, conv, output] - nodes = [node for node in gm.graph.nodes] + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - # find the sharding strategies for the input node of the conv node - # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]] - strategies_vector_for_input = StrategiesVector(nodes[1]) - sharding_option = (None, 0, 1) - for first_sharding_index in sharding_option: - for second_sharding_index in sharding_option: - if first_sharding_index is not None and second_sharding_index == first_sharding_index: - continue - if first_sharding_index is None: - first_dim_spec = _DimSpec([]) - else: - first_dim_spec = _DimSpec([first_sharding_index]) - - if second_sharding_index is None: - second_dim_spec = _DimSpec([]) - else: - second_dim_spec = _DimSpec([second_sharding_index]) - - replica_dim_spec = _DimSpec([]) - sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec] - sharding_spec = ShardingSpec(device_mesh=device_mesh, - entire_shape=entire_shape, - sharding_sequence=sharding_sequence) - strategy_name = str(sharding_spec.sharding_sequence) - sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec) - strategies_vector_for_input.append(sharding_strategy) - setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) - - # generate conv strategy - strategies_vector = StrategiesVector(node=nodes[2]) - conv_handler = ConvHandler( - node=nodes[2], - device_mesh=device_mesh, - strategies_vector=strategies_vector, - ) - conv_handler.register_strategy() + strategies_constructor.build_strategies_and_cost() + conv_node = list(graph.nodes)[4] # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'] - strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector] + strategy_name_list = [strategy.name for strategy in conv_node.strategies_vector] # SS = SR x RS assert 'S0S1 = S0R x RS1' in strategy_name_list diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py index e901b84a3..0a2dba161 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py @@ -1,14 +1,16 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.device.device_mesh import DeviceMesh class LinearModel(nn.Module): @@ -23,6 +25,7 @@ class LinearModel(nn.Module): return x +@pytest.mark.skip('F.linear is not supported in deprecated handler') def test_dot_handler(): physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -37,52 +40,23 @@ def test_dot_handler(): # graph(): # %x : torch.Tensor [#users=1] = placeholder[target=x] # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) - # return conv + # %linear_weight : [#users=1] = get_attr[target=linear.weight] + # %linear_bias : [#users=1] = get_attr[target=linear.bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %linear_weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) + # return add graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - # [x, mul, linear, output] - nodes = [node for node in gm.graph.nodes] + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - # find the sharding strategies for the input node of the conv node - # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]] - strategies_vector_for_input = StrategiesVector(node=nodes[1]) - sharding_option = (None, 0, 1) - for first_sharding_index in sharding_option: - for second_sharding_index in sharding_option: - if first_sharding_index is not None and second_sharding_index == first_sharding_index: - continue - if first_sharding_index is None: - first_dim_spec = _DimSpec([]) - else: - first_dim_spec = _DimSpec([first_sharding_index]) - - if second_sharding_index is None: - second_dim_spec = _DimSpec([]) - else: - second_dim_spec = _DimSpec([second_sharding_index]) - - sharding_sequence = [first_dim_spec, second_dim_spec] - sharding_spec = ShardingSpec(device_mesh=device_mesh, - entire_shape=entire_shape, - sharding_sequence=sharding_sequence) - strategy_name = str(sharding_spec.sharding_sequence) - sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec) - strategies_vector_for_input.append(sharding_strategy) - setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) - - # generate dot strategy - strategies_vector = StrategiesVector(node=nodes[2]) - dot_handler = DotHandler( - node=nodes[2], - device_mesh=device_mesh, - strategies_vector=strategies_vector, - ) - strategies_vector = dot_handler.register_strategy() + strategies_constructor.build_strategies_and_cost() + linear_node = list(graph.nodes)[4] # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR'] - strategy_name_list = [strategy.name for strategy in strategies_vector] + strategy_name_list = [strategy.name for strategy in linear_node.strategies_vector] # SS = SR x RS assert 'S0S1 = S0R x RS1' in strategy_name_list diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py index c895dff4e..ac9df4cd8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py @@ -1,12 +1,11 @@ import torch -from torch.fx import GraphModule import torch.nn as nn -import pytest +from torch.fx import GraphModule from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer class ConvModel(nn.Module): @@ -33,7 +32,12 @@ def test_conv_handler(): input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} # graph(): # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # %flatten : [#users=1] = call_function[target=torch.flatten](args = (%add,), kwargs = {}) # return flatten graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) @@ -44,10 +48,10 @@ def test_conv_handler(): strategies_constructor.build_strategies_and_cost() strategy_map = strategies_constructor.strategy_map - conv_strategies = strategy_map[nodes[1]] - flatten_strategies = strategy_map[nodes[2]] + add_strategies = strategy_map[nodes[5]] + flatten_strategies = strategy_map[nodes[6]] flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies] - for strategy in conv_strategies: + for strategy in add_strategies: assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py index 7886de5ad..9be1a5d96 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py @@ -1,17 +1,18 @@ -import torch -from torch.fx import GraphModule -import torch.nn as nn -import pytest +from copy import deepcopy +import pytest +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST -from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor -from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions -from copy import deepcopy class ConvModel(nn.Module): @@ -40,9 +41,14 @@ def test_strategies_constructor(): # graph(): # %x : torch.Tensor [#users=1] = placeholder[target=x] # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) - # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) - # return conv + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # return add graph = tracer.trace(root=model, meta_args=input_sample) + print(graph) gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() @@ -63,12 +69,12 @@ def test_strategies_constructor(): # Third node is conv. conv_check_list = deepcopy(CONV_STRATEGIES_LIST) - for strategy in strategies_constructor.leaf_strategies[2]: + for strategy in strategies_constructor.leaf_strategies[4]: conv_check_list.remove(strategy.name) assert len(conv_check_list) == 0 # In fast mode, output node only has replica strategy. - assert strategies_constructor.leaf_strategies[3][0].name == 'Replica Output' + assert strategies_constructor.leaf_strategies[7][0].name == 'Replica Output' # check strategy_map @@ -81,15 +87,15 @@ def test_strategies_constructor(): mul = nodes[1] assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0' - # Third node is conv. - conv = nodes[2] + # fifth node is conv. + conv = nodes[4] conv_check_list = deepcopy(CONV_STRATEGIES_LIST) for strategy in strategies_constructor.strategy_map[conv]: conv_check_list.remove(strategy.name) assert len(conv_check_list) == 0 # In fast mode, output node only has replica strategy. - output = nodes[3] + output = nodes[-1] assert strategies_constructor.strategy_map[output][0].name == 'Replica Output' diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py new file mode 100644 index 000000000..22a237131 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py @@ -0,0 +1,279 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel +from transformers.pytorch_utils import Conv1D + + +class GPT2MLP(nn.Module): + + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + # We temporarily banned the Dropout layer because the rng state need + # to process to get the correct result. + # self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + # TODO: the rng state need to be fixed for distributed runtime + # hidden_states = self.dropout(hidden_states) + return hidden_states + + +# The reason Why we don't import GPT2Attention from transformers directly is that: +# 1. The tracer will not work correctly when we feed meta_args and concrete_args at same time, +# so we have to build the customized GPT2Attention class and remove the conditional branch manually. +# 2. The order of split and view op has been changed in the customized GPT2Attention class, the new +# order is same as megatron-lm gpt model. +class GPT2Attention(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), + dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + self.scale_attn_weights = config.scale_attn_weights + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / (value.size(-1)**0.5) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + # attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + + # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + qkv = self.c_attn(hidden_states) + + # query = self._split_heads(query, self.num_heads, self.head_dim) + # key = self._split_heads(key, self.num_heads, self.head_dim) + # value = self._split_heads(value, self.num_heads, self.head_dim) + query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) + present = (key, value) + + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + # attn_output = self.resid_dropout(attn_output) + return attn_output + + +class GPT2Block(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + # %transformer_h_0_ln_1 + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + ) + # residual connection + hidden_states = attn_outputs + residual + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + return hidden_states + + +class GPT2Model(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + + device = input_ids.device + + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + + # add_2 + hidden_states = inputs_embeds + position_embeds + + # comment to run pipeline + # add_3 + output_shape = input_shape + (hidden_states.size(-1),) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i]) + hidden_states = outputs + + hidden_states = self.ln_f(hidden_states) + # comment to run pipeline + hidden_states = hidden_states.view(output_shape) + + return hidden_states + + +class GPT2LMHeadModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ): + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + lm_logits = self.lm_head(transformer_outputs) + + return lm_logits + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py new file mode 100644 index 000000000..0979d8353 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py @@ -0,0 +1,131 @@ +import copy +import random +from functools import partial +from time import time +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import psutil +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import transformers +from torch.fx import GraphModule +from torch.profiler import ProfilerActivity, profile, record_function, schedule, tensorboard_trace_handler + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP +from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.initialize import launch, launch_from_torch +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global +from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2LMHeadModel, GPTLMLoss + +BATCH_SIZE = 32 +SEQ_LENGTH = 256 +HIDDEN_DIM = 16384 +NUM_HEADS = 128 +NUM_LAYERS = 4 +VOCAB_SIZE = 50257 +NUM_STEPS = 10 +FP16 = True + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4 + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def main(): + disable_existing_loggers() + launch_from_torch(config={}) + logger = get_dist_logger() + config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM) + if FP16: + model = GPT2LMHeadModel(config=config).half().to('cuda') + else: + model = GPT2LMHeadModel(config=config).to('cuda') + global_numel = sum([p.numel() for p in model.parameters()]) + + meta_input_sample = { + 'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + 'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + } + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + gm = initialize_model(model, meta_input_sample, device_mesh) + + # build criterion + criterion = GPTLMLoss() + + optimizer = torch.optim.Adam(gm.parameters(), lr=0.01) + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH) + torch.cuda.synchronize() + model.train() + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + # schedule=schedule(wait=1, warmup=2, active=2), + # on_trace_ready=tensorboard_trace_handler(f'log/dummy_data/bs128_seq128_new'), + # record_shapes=True, + # profile_memory=True) as prof: + # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as prof: + for n in range(10): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE) + optimizer.zero_grad() + start = time() + outputs = gm(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + loss.backward() + optimizer.step() + # prof.step() + torch.cuda.synchronize() + step_time = time() - start + logger.info( + f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', + ranks=[0]) + # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)) + torch.cuda.synchronize() + + +if __name__ == '__main__': + main() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py new file mode 100644 index 000000000..c7f9988f1 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -0,0 +1,207 @@ +import copy +import random +from functools import partial +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import transformers +from torch.fx import GraphModule + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global +from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model + +BATCH_SIZE = 1 +SEQ_LENGTH = 32 +HIDDEN_DIM = 768 + +seed = 128 +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) +np.random.seed(seed) +random.seed(seed) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor], + best_sharding_spec_dict: Dict[str, ShardingSpec]): + for name, param in module.named_parameters(): + param_grad = param.grad + origin_param_grad = origin_param_dict[name].grad + atoms = name.split('.') + new_name = '_'.join(atoms) + if new_name in best_sharding_spec_dict: + param_sharding_spec = best_sharding_spec_dict[new_name] + grad_to_compare = copy.deepcopy(param_grad) + param_grad_global = to_global(grad_to_compare, param_sharding_spec) + + try: + assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-03) + except: + difference = param_grad_global - origin_param_grad + avg_diff = difference.abs().sum() / difference.numel() + assert avg_diff < 0.001 + print(f'{name} param has {avg_diff} average difference') + + +def check_attention_layer(rank, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + config = transformers.GPT2Config(n_position=64, n_layer=1, n_head=16, n_embd=HIDDEN_DIM) + + if model_cls == GPT2MLP: + model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda') + else: + model = model_cls(config=config).to('cuda') + test_model = copy.deepcopy(model) + + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32) + + if model_cls == GPT2MLP: + input_sample = (hidden_states.to('cuda'),) + test_input_sample = copy.deepcopy(input_sample) + meta_input_sample = { + 'hidden_states': hidden_states.to('meta'), + } + elif model_cls in (GPT2Attention, GPT2Block): + input_sample = ( + hidden_states.to('cuda'), + attention_mask.to('cuda'), + ) + test_input_sample = copy.deepcopy(input_sample) + meta_input_sample = { + 'hidden_states': hidden_states.to('meta'), + 'attention_mask': attention_mask.to('meta'), + } + else: + input_sample = ( + input_ids.to('cuda'), + attention_mask.to('cuda'), + ) + test_input_sample = copy.deepcopy(input_sample) + meta_input_sample = { + 'input_ids': input_ids.to('meta'), + 'attention_mask': attention_mask.to('meta'), + } + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + + graph = tracer.trace(root=model, meta_args=meta_input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1) + ret = solver.call_solver_serialized_args() + + solution = list(ret[0]) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh, strategies_constructor) + gm = runtime_apply_pass(gm) + gm.recompile() + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + best_sharding_spec_dict = {} + for index, node in enumerate(nodes): + best_sharding_spec_dict[node.name] = node.sharding_spec + + cuda_rng_state = torch.cuda.get_rng_state() + cpu_rng_state = torch.get_rng_state() + origin_output = test_model(*test_input_sample) + torch.cuda.set_rng_state(cuda_rng_state) + torch.set_rng_state(cpu_rng_state) + output = gm(*input_sample, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + assert_close(output, origin_output, rtol=1e-03, atol=1e-03) + + #*******************backward starting******************* + cuda_rng_state = torch.cuda.get_rng_state() + cpu_rng_state = torch.get_rng_state() + output.sum().backward() + torch.set_rng_state(cpu_rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + origin_output.sum().backward() + origin_param_dict = dict(test_model.named_parameters()) + + if rank == 0: + print("*******************backward starting*******************") + + _check_module_grad(gm, origin_param_dict, best_sharding_spec_dict) + + if rank == 0: + print("*******************backward finished*******************") + + #*******************backward finished******************* + + #*******************strategy selected******************* + if rank == 0: + print("*******************strategy selected*******************") + strategies_list = solver.last_s_val + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + computation_cost = 0 + communication_cost = 0 + memory_cost = 0 + for index, node in enumerate(nodes): + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost.activation + node_memory_cost.parameter + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) +@rerun_if_address_is_in_use() +def test_mlp_layer(model_cls): + world_size = 4 + run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py new file mode 100644 index 000000000..26ad0d3a0 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn +import transformers +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.testing import parameterize +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model + +BATCH_SIZE = 1 +SEQ_LENGTH = 32 +HIDDEN_DIM = 768 + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) +def test_self_attention_block(model_cls): + config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM) + if model_cls == GPT2MLP: + model = model_cls(intermediate_size=4 * config.hidden_size, config=config) + else: + model = model_cls(config=config) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + if model_cls == GPT2MLP: + input_sample = { + 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + } + elif model_cls in (GPT2Attention, GPT2Block): + input_sample = { + 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + 'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'), + } + else: + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + input_sample = {k: v.to('meta') for k, v in kwargs.items()} + + graph = tracer.trace(root=model, meta_args=input_sample) + + gm = GraphModule(model, graph, model.__class__.__name__) + print(gm.graph) + gm.recompile() + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1) + ret = solver.call_solver_serialized_args() + strategies_list = solver.last_s_val + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + + computation_cost = 0 + communication_cost = 0 + memory_cost = 0 + for index, node in enumerate(nodes): + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost.activation + node_memory_cost.parameter + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + + +if __name__ == '__main__': + test_self_attention_block() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py new file mode 100644 index 000000000..f468b1ab2 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -0,0 +1,61 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +def _ReLU_module_mem_test(rank, world_size, port): + """This function is for ReLU memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.ReLU()).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 1 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_ReLU_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_ReLU_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_ReLU_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py new file mode 100644 index 000000000..826c74666 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py @@ -0,0 +1,60 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +def _batchnorm_module_mem_test(rank, world_size, port): + """This function is for batchnorm memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.BatchNorm2d(128)).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 9 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_batchnorm_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_batchnorm_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py new file mode 100644 index 000000000..1b745d890 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py @@ -0,0 +1,71 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +class BinaryElementwiseOpModule(nn.Module): + + def __init__(self, token=torch.add, shape=64) -> None: + super().__init__() + self.token = token + self.param = nn.Parameter(torch.rand(shape)) + + def forward(self, input): + return input + self.param + + +def _binary_elementwise_mem_test(rank, world_size, port): + """This function is for binary elementwise ops memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda() + input = torch.rand(32, 1024).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 2 + # total number of target node strategies + strategy_number = 9 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_binary_elementwise_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_binary_elementwise_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py new file mode 100644 index 000000000..a973a8182 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py @@ -0,0 +1,113 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +class ConvFunctionModule(nn.Module): + + def __init__(self, in_channels=4, out_channels=64, kernel_size=3): + super().__init__() + self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + + def forward(self, input): + return nn.functional.conv2d(input, self.conv_weight) + + +def _conv_module_mem_test(rank, bias, world_size, port): + """This function is for conv memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda() + input = torch.rand(4, 4, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 16 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_conv_meta_concrete_info_match(bias=False): + world_size = 4 + run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +def _conv_function_mem_test(rank, world_size, port): + """This function is for conv function memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvFunctionModule().cuda() + input = torch.rand(4, 4, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 2 + # total number of target node strategies + strategy_number = 16 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_conv_function_concrete_info_match(): + world_size = 4 + run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + # test_conv_meta_concrete_info_match() + test_conv_function_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py new file mode 100644 index 000000000..e9c0601eb --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -0,0 +1,111 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + + +class MyModule(nn.Module): + + def __init__(self, in_features=64, out_features=128): + super().__init__() + self.fc_weight = nn.Parameter(torch.randn(out_features, in_features)) + + def forward(self, input): + return nn.functional.linear(input, self.fc_weight) + + +def _linear_module_mem_test(rank, world_size, port): + """This function is for linear memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether linear module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda() + input = torch.rand(8, 8, 16, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # memory test + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=1, + strategy_number=13, + input_args=[input], + meta_arg_names=["input"]) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_module_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +def _linear_function_mem_test(rank, world_size, port): + """This function is for linear memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether linear module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = MyModule().cuda() + input = torch.rand(8, 8, 16, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # memory test + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=2, + strategy_number=24, + input_args=[input], + meta_arg_names=["input"]) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_function_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + # test_linear_module_meta_concrete_info_match() + test_linear_function_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py new file mode 100644 index 000000000..529686d27 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py @@ -0,0 +1,102 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy + + +def _adaptiveavgpool_module_mem_test(rank, world_size, port): + """This function is for AdaptiveAvgPool memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.AdaptiveAvgPool2d((16, 16))).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target strategies + strategy_number = 1 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_adaptiveavgpool_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +def _maxpool_module_mem_test(rank, world_size, port): + """This function is for MaxPool memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether conv module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.MaxPool2d((16, 16))).cuda() + input = torch.rand(4, 128, 64, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of target node in computation graph + node_index = 1 + # total number of target node strategies + strategy_number = 9 + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_maxpool_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_adaptiveavgpool_meta_concrete_info_match() + test_maxpool_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py new file mode 100644 index 000000000..7c06f2ee9 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -0,0 +1,128 @@ +import copy +from pprint import pprint +from typing import Dict, List + +import torch +from torch.fx import GraphModule + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType +from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer + +if torch.__version__ >= '1.12.0': + from colossalai.auto_parallel.meta_profiler import MetaInfo + + +def mem_test_for_node_strategy(rank: int, + model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}): + for strategy_index in range(strategy_number): + # We need to copy the model to avoid do backward more than once in same graph + model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy( + input_kwargs) + + tracer = ColoTracer() + input_sample = {} + for input_arg, meta_arg_name in zip(input_args, meta_arg_names): + input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') + for meta_kwarg_name, input_kwarg in input_kwargs.items(): + input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') + graph = tracer.trace(root=model_to_shard, meta_args=input_sample) + gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + target_node = list(graph.nodes)[node_index] + + # solution construction + # construct the strategy for the target node + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + + # construct the strategy for the output node + placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0] + + output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() + if key.type == OperationDataType.OUTPUT) + placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[ + output_key] + + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh) + gm = runtime_apply_pass(gm) + gm.recompile() + gm: GraphModule + + num_of_strategies = len(target_node.strategies_vector) + if rank == 0: + print("=======================") + print(f"#strategy_index: {strategy_index + 1}/{num_of_strategies}") + pprint(target_node.strategies_vector[strategy_index]) + + # warmup + with torch.no_grad(): + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + + del output + # forward memory compare + if rank == 0: + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + + if rank == 0: + # print forward memory allocated and peak memory stats in kb + print( + f"forward memory allocated: {(torch.cuda.memory_allocated() - mem_stamp0) / 1024} kb, peak memory stats: {(torch.cuda.max_memory_allocated() - mem_stamp0) / 1024} kb" + ) + + # backward memory compare + grad_tensors = torch.ones_like(output) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output, grad_tensors) + + if rank == 0: + # print backward memory allocated and peak memory stats in kb + print( + f"backward memory allocated: {(torch.cuda.memory_allocated() - mem_stamp0) / 1024} kb, peak memory stats: {(torch.cuda.max_memory_allocated() - mem_stamp0) / 1024} kb" + ) + + # estimated memory + if target_node.op == "call_module": + metainfo = MetaInfo(target_node.strategies_vector[strategy_index], + target_node.graph.owning_module.get_submodule(target_node.target)) + else: + metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target) + + print("estimated memory:") + print( + f"forward activation: {metainfo.memory_cost.fwd.activation / 1024} kb, forward param: {metainfo.memory_cost.fwd.parameter / 1024} kb" + ) + print( + f"forward temp: {metainfo.memory_cost.fwd.temp / 1024} kb, forward buffer: {metainfo.memory_cost.fwd.buffer / 1024} kb" + ) + print( + f"backward activation: {metainfo.memory_cost.bwd.activation / 1024} kb, backward param: {metainfo.memory_cost.bwd.parameter / 1024} kb" + ) + print( + f"backward temp: {metainfo.memory_cost.bwd.temp / 1024} kb, backward buffer: {metainfo.memory_cost.bwd.buffer / 1024} kb" + ) + print("=======================") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py index 54cd473b4..ffc15e403 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -1,48 +1,96 @@ +from functools import partial + +import pytest import torch +import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.testing import parameterize +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy class AddBMMTensorMethodModule(nn.Module): + def __init__(self, using_kwargs): + super().__init__() + self.using_kwargs = using_kwargs + def forward(self, bias, x1, x2): - return bias.addbmm(x1, x2) + if self.using_kwargs: + output = bias.addbmm(x1, x2, alpha=2, beta=3) + else: + output = bias.addbmm(x1, x2) + return output class AddBMMTorchFunctionModule(nn.Module): + def __init__(self, using_kwargs): + super().__init__() + self.using_kwargs = using_kwargs + def forward(self, bias, x1, x2): - return torch.addbmm(bias, x1, x2) + if self.using_kwargs: + output = torch.addbmm(bias, x1, x2, alpha=2, beta=3) + else: + output = torch.addbmm(bias, x1, x2) + return output -@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) -@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) -def test_2d_device_mesh(module, bias_shape): - - model = module() +def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = module(using_kwargs).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + bias = torch.rand(bias_shape).cuda() + # the index of addbmm node in computation graph + node_index = 3 + # strategy number of addbmm node on 2d device mesh + strategy_number = 7 + # construct input args + input_args = [bias, x1, x2] + # construct meta arg names + meta_arg_names = ['bias', 'x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) tracer = ColoTracer() + # graph(): + # %bias : torch.Tensor [#users=1] = placeholder[target=bias] + # %x1 : torch.Tensor [#users=1] = placeholder[target=x1] + # %x2 : torch.Tensor [#users=1] = placeholder[target=x2] + # %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {}) + # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) + # return add graph = tracer.trace(model, meta_args={ 'bias': torch.rand(*bias_shape).to('meta'), "x1": torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta') }) - print(graph) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - linear_mod_node = list(graph.nodes)[3] - strategies_vector = StrategiesVector(linear_mod_node) + bmm_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(bmm_mod_node) # build handler - handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) # check operation data mapping mapping = handler.get_operation_data_mapping() @@ -65,20 +113,15 @@ def test_2d_device_mesh(module, bias_shape): assert mapping['other'].type == OperationDataType.ARG assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size(bias_shape) - assert mapping['bias'].type == OperationDataType.ARG - assert mapping['bias'].logical_shape == torch.Size([8, 8]) - - assert mapping['output'].name == "addbmm" + assert mapping['output'].name == "bmm" assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([8, 8]) + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) assert mapping['output'].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] - + for name in strategy_name_list: + print(name) # one batch dim assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list @@ -100,38 +143,60 @@ def test_2d_device_mesh(module, bias_shape): for strategy in strategies_vector: input_sharding_spec = strategy.get_sharding_spec_by_name('x1') other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') - output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') # make sure the sharding matches across different operation data - assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] - assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) -@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) -def test_1d_device_mesh(module, bias_shape): - model = module() +def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (1, 4) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + model = module(using_kwargs).cuda() + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + bias = torch.rand(bias_shape).cuda() + # the index of addbmm node in computation graph + node_index = 3 + # strategy number of addbmm node on 2d device mesh + strategy_number = 1 + # construct input args + input_args = [bias, x1, x2] + # construct meta arg names + meta_arg_names = ['bias', 'x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer() + # graph(): + # %bias : torch.Tensor [#users=1] = placeholder[target=bias] + # %x1 : torch.Tensor [#users=1] = placeholder[target=x1] + # %x2 : torch.Tensor [#users=1] = placeholder[target=x2] + # %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {}) + # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) + # return add graph = tracer.trace(model, meta_args={ 'bias': torch.rand(*bias_shape).to('meta'), "x1": torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta') }) - print(graph) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - - mesh_shape = (1, 4) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - linear_mod_node = list(graph.nodes)[3] - strategies_vector = StrategiesVector(linear_mod_node) + bmm_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(bmm_mod_node) # build handler - handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) # check operation data mapping mapping = handler.get_operation_data_mapping() @@ -154,15 +219,9 @@ def test_1d_device_mesh(module, bias_shape): assert mapping['other'].type == OperationDataType.ARG assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size(bias_shape) - assert mapping['bias'].type == OperationDataType.ARG - assert mapping['bias'].logical_shape == torch.Size([8, 8]) - - assert mapping['output'].name == "addbmm" + assert mapping['output'].name == "bmm" assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([8, 8]) + assert mapping['output'].data.shape == torch.Size([4, 8, 8]) assert mapping['output'].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) @@ -174,16 +233,50 @@ def test_1d_device_mesh(module, bias_shape): for strategy in strategies_vector: input_sharding_spec = strategy.get_sharding_spec_by_name('x1') other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') - output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') # make sure the sharding matches across different operation data - assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] - assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +@pytest.mark.skip("skip due to bias cases not ready") +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) +@parameterize('using_kwargs', [True, False]) +@rerun_if_address_is_in_use() +def test_2d_device_mesh(module, bias_shape, using_kwargs): + world_size = 4 + run_func = partial(check_2d_device_mesh, + module=module, + bias_shape=bias_shape, + world_size=world_size, + using_kwargs=using_kwargs, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +@pytest.mark.skip("skip due to bias cases not ready") +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) +@parameterize('using_kwargs', [True, False]) +@rerun_if_address_is_in_use() +def test_1d_device_mesh(module, bias_shape, using_kwargs): + world_size = 4 + run_func = partial(check_1d_device_mesh, + module=module, + bias_shape=bias_shape, + using_kwargs=using_kwargs, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__': test_1d_device_mesh() - # test_2d_device_mesh() + test_2d_device_mesh() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py new file mode 100644 index 000000000..aa5a57474 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -0,0 +1,203 @@ +from faulthandler import disable +from functools import partial +from xml.dom import WrongDocumentErr + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from typing_extensions import Self + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class AddmmModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, m1, m2): + x = torch.addmm(input, m1, m2, beta=3, alpha=2) + return x + + +class AddmmModel_with_param(nn.Module): + + def __init__(self, weight_shape, bias_shape): + super().__init__() + self.weight = torch.nn.Parameter(torch.rand(weight_shape)) + self.bias = torch.nn.Parameter(torch.rand(bias_shape)) + + def forward(self, m1): + x = torch.addmm(self.bias, m1, self.weight, beta=3, alpha=2) + return x + + +def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if model_cls == AddmmModel: + model = AddmmModel().cuda() + else: + model = AddmmModel_with_param(weight_shape=(8, 16), bias_shape=input_shape).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + if model_cls == AddmmModel: + input = torch.rand(input_shape).cuda() + m1 = torch.rand(4, 8).cuda() + m2 = torch.rand(8, 16).cuda() + # construct input args + input_args = [input, m1, m2] + # construct meta arg names + meta_arg_names = ['input', 'm1', 'm2'] + meta_args_for_tracer = {} + for meta_arg, input_arg in zip(meta_arg_names, input_args): + meta_args_for_tracer[meta_arg] = input_arg.to('meta') + + # the index of addmm node in computation graph + node_index = 4 + # strategy number of linear node + strategy_number = 14 + else: + m1 = torch.rand(4, 8).cuda() + # construct input args + input_args = [m1] + # construct meta arg names + meta_arg_names = ['m1'] + # the index of addmm node in computation graph + meta_args_for_tracer = {} + for meta_arg, input_arg in zip(meta_arg_names, input_args): + meta_args_for_tracer[meta_arg] = input_arg.to('meta') + node_index = 4 + # strategy number of linear node + strategy_number = 14 + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type='bias_module') + + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %m1 : torch.Tensor [#users=1] = placeholder[target=m1] + # %m2 : torch.Tensor [#users=1] = placeholder[target=m2] + # %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {}) + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {}) + # %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) + # return add + graph = tracer.trace(model, meta_args=meta_args_for_tracer) + gm = ColoGraphModule(model, graph) + # [input_1, m1, m2, addmm, output] + node_list = list(graph.nodes) + linear_node = node_list[4] + strategies_vector = StrategiesVector(linear_node) + + # build handler + handler = LinearFunctionHandler(node=linear_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + assert mapping['input'].name == "m1" + assert mapping['input'].data.shape == torch.Size([4, 8]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8]) + + assert mapping['other'].name == "transpose" + assert mapping['other'].data.shape == torch.Size([16, 8]) + if model_cls == AddmmModel: + assert mapping['other'].type == OperationDataType.ARG + else: + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([8, 16]) + + assert mapping['output'].name == "linear" + assert mapping['output'].data.shape == torch.Size([4, 16]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('m1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('transpose') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('input_shape', [(16,), (4, 16)]) +@parameterize('model_cls', [AddmmModel, AddmmModel_with_param]) +@rerun_if_address_is_in_use() +def test_addmm_handler(input_shape, model_cls): + world_size = 4 + run_func_function = partial(check_addmm_function_handler, + input_shape=input_shape, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_function, nprocs=world_size) + + +if __name__ == '__main__': + test_addmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py index e6ab63a12..0ab70abff 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -1,18 +1,43 @@ +from functools import partial + +import pytest import torch +import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import \ - BatchNormModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.fx.tracer.meta_patch.patched_module import linear -import pytest +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -@pytest.mark.skip("skip due to passes not ready") -def test_bn_module_handler(): - model = nn.Sequential(nn.BatchNorm2d(16).to('meta')) +def check_bn_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.BatchNorm2d(16)).cuda() + + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 16, 64, 64).cuda() + # the index of bn node in computation graph + node_index = 1 + # the total number of bn strategies without sync bn mode + # TODO: add sync bn stategies after related passes ready + strategy_number = 4 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) tracer = ColoTracer() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -20,10 +45,6 @@ def test_bn_module_handler(): # return _0 graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')}) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) bn_mod_node = list(graph.nodes)[1] strategies_vector = StrategiesVector(bn_mod_node) @@ -40,25 +61,21 @@ def test_bn_module_handler(): assert op_data.data is not None assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64]) assert mapping['input'].type == OperationDataType.ARG assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64]) assert mapping['other'].name == "weight" - assert mapping['other'].data.is_meta assert mapping['other'].data.shape == torch.Size([16]) assert mapping['other'].type == OperationDataType.PARAM assert mapping['other'].logical_shape == torch.Size([16]) assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta assert mapping['bias'].data.shape == torch.Size([16]) assert mapping['bias'].type == OperationDataType.PARAM assert mapping['bias'].logical_shape == torch.Size([16]) assert mapping['output'].name == "_0" - assert mapping['output'].data.is_meta assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) assert mapping['output'].type == OperationDataType.OUTPUT @@ -75,16 +92,27 @@ def test_bn_module_handler(): # RS01 = RS01 x S01 assert 'RS01 = RS01 x S01' in strategy_name_list + # temporarily skip the sync bn test + # TODO: test sync bn after the implicit runtime pass completed # SR = SR x R WITH SYNC_BN - assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list - assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list + # assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list + # assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list # SS = SS x S WITH SYNC_BN - assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list - assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list + # assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list + # assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list # S01R = S01R x R WITH SYNC_BN - assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list + # assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bn_module_handler(): + world_size = 4 + run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py new file mode 100644 index 000000000..162d1fbba --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py @@ -0,0 +1,177 @@ +from faulthandler import disable +from functools import partial +from xml.dom import WrongDocumentErr + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +from typing_extensions import Self + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + +WEIGHT_SHAPE = (32, 16) + + +class LinearModule(torch.nn.Module): + + def __init__(self, weight_shape): + super().__init__() + self.weight = torch.nn.Parameter(torch.rand(*weight_shape)) + self.bias = torch.nn.Parameter(torch.rand(weight_shape[0])) + + def forward(self, x): + x = F.linear(x, self.weight, bias=self.bias) + return x + + +def check_linear_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModule(weight_shape=WEIGHT_SHAPE).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 4, 16).cuda() + # the index of linear node in computation graph + node_index = 3 + # strategy number of linear node + strategy_number = 24 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['x'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type='bias_module') + + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %weight : [#users=1] = get_attr[target=weight] + # %bias : [#users=1] = get_attr[target=bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {}) + # return add + graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')}) + gm = ColoGraphModule(model, graph) + + linear_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x" + assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([64, 16]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + assert 'bias' not in mapping + + assert mapping['output'].name == "linear" + assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('x') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_handler(): + world_size = 4 + run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py new file mode 100644 index 000000000..c5c3f3781 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -0,0 +1,166 @@ +from faulthandler import disable +from functools import partial +from xml.dom import WrongDocumentErr + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from typing_extensions import Self + +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing.utils import parameterize +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearModule(torch.nn.Module): + + def __init__(self, in_features, out_features, bias): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x): + x = self.linear(x) + return x + + +def check_linear_module_handler(rank, bias, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModule(16, 32, bias=bias).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 4, 16).cuda() + # the index of linear node in computation graph + node_index = 3 + # strategy number of linear node + strategy_number = 24 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['x'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type='bias_module') + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')}) + gm = ColoGraphModule(model, graph) + + linear_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x" + assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([64, 16]) + + assert mapping['other'].name == "linear_weight" + assert mapping['other'].data.shape == torch.Size([32, 16]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([16, 32]) + + assert 'bias' not in mapping + + assert mapping['output'].name == "linear" + assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('x') + weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_handler(bias=True): + world_size = 4 + run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +if __name__ == '__main__': + test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py new file mode 100644 index 000000000..42430d5a2 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py @@ -0,0 +1,232 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + class BinaryElementwiseOpModel(nn.Module): + + def __init__(self, op): + super().__init__() + self.op = op + + def forward(self, x1, x2): + out = self.op(x1, x2) + return out + + model = BinaryElementwiseOpModel(op).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 4).cuda() + x2 = torch.rand([4] * other_dim).cuda() + # the index of binary-elementwise node in computation graph + node_index = 2 + # strategy number of binary-elementwise node + strategy_number = 9 + # construct input args + input_args = [x1, x2] + # construct meta arg names + meta_arg_names = ['x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer() + meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + + op_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(op_node) + + # build handler + handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4] * other_dim) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 4]) + + assert mapping['output'].name == str(op_node) + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([4, 4]) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # one strategy will be converted to different physical sharding spec + assert len(strategy_name_list) == 9 + + # check if the sharding strategy is correct + assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list + assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list + assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list + assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list + assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list + assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list + assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list + assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list + assert '[R, R] = [R, R] [R, R]' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) + + # make sure the sharding spec is the same for input and output + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence + + # since the dim of the other can change, we make sure at least its last dim sharding is the same + if len(other_sharding_spec.sharding_sequence) == 2: + assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence + elif len(other_sharding_spec.sharding_sequence) == 1: + assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] + + +def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + class BinaryElementwiseOpModel(nn.Module): + + def __init__(self, op, const): + super().__init__() + self.op = op + self.const = const + + def forward(self, x1): + out = self.op(x1, self.const) + return out + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + model = BinaryElementwiseOpModel(op, other_dim).cuda() + x1 = torch.rand(4, 4).cuda() + # the index of binary-elementwise node in computation graph + node_index = 1 + # strategy number of binary-elementwise node + strategy_number = 9 + # construct input args + input_args = [x1] + # construct meta arg names + meta_arg_names = ['x1'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + tracer = ColoTracer() + meta_args = {'x1': torch.rand(4, 4).to('meta')} + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + + op_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(op_node) + + # build handler + handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4]) + + assert mapping['output'].name == str(op_node) + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 4]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([4, 4]) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # one strategy will be converted to different physical sharding spec + assert len(strategy_name_list) == 9 + + # check if the sharding strategy is correct + assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list + assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list + assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list + assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list + assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list + assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list + assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list + assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list + assert '[R, R] = [R, R] [R, R]' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) + + # make sure the sharding spec is the same for input and output + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@parameterize('op', [torch.add]) +@parameterize('other_dim', [1, 2]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_binary_elementwise_handler(op, other_dim): + world_size = 4 + run_func_tensor = partial(check_binary_elementwise_handler_with_tensor, + op=op, + other_dim=other_dim, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_tensor, nprocs=world_size) + run_func_int = partial(check_binary_elementwise_handler_with_int, + op=op, + other_dim=other_dim, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_int, nprocs=world_size) + + +if __name__ == '__main__': + test_binary_elementwise_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index f59fea90d..02c7e0671 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -1,12 +1,20 @@ +from functools import partial + import pytest import torch +import torch.multiprocessing as mp import torch.nn as nn from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.testing import parameterize +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy class BMMTensorMethodModule(nn.Module): @@ -21,22 +29,37 @@ class BMMTorchFunctionModule(nn.Module): return torch.bmm(x1, x2) -@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) -def test_2d_device_mesh(module): - - model = module() +def check_2d_device_mesh(rank, module, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = module().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + # the index of bmm node in computation graph + node_index = 2 + # strategy number of bmm node on 2d device mesh + strategy_number = 7 + # construct input args + input_args = [x1, x2] + # construct meta arg names + meta_arg_names = ['x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) tracer = ColoTracer() graph = tracer.trace(model, meta_args={ "x1": torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta') }) - print(graph) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) linear_mod_node = list(graph.nodes)[2] strategies_vector = StrategiesVector(linear_mod_node) @@ -96,27 +119,41 @@ def test_2d_device_mesh(module): output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') # make sure the sharding matches across different operation data - print(input_sharding_spec.sharding_sequence, output_sharding_spec.sharding_sequence) assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) -def test_1d_device_mesh(module): - model = module() +def check_1d_device_mesh(rank, module, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = module().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (1, 4) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + x1 = torch.rand(4, 8, 16).cuda() + x2 = torch.rand(4, 16, 8).cuda() + # the index of bmm node in computation graph + node_index = 2 + # strategy number of bmm node on 1d device mesh + strategy_number = 1 + # construct input args + input_args = [x1, x2] + # construct meta arg names + meta_arg_names = ['x1', 'x2'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) tracer = ColoTracer() graph = tracer.trace(model, meta_args={ "x1": torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta') }) - print(graph) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - - mesh_shape = (1, 4) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) linear_mod_node = list(graph.nodes)[2] strategies_vector = StrategiesVector(linear_mod_node) @@ -166,6 +203,17 @@ def test_1d_device_mesh(module): assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] +@run_on_environment_flag(name='AUTO_PARALLEL') +@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bmm_handler(module): + world_size = 4 + run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port()) + mp.spawn(run_func_2d, nprocs=world_size) + run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port()) + mp.spawn(run_func_1d, nprocs=world_size) + + if __name__ == '__main__': - test_1d_device_mesh() - test_2d_device_mesh() + test_bmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index 97025729c..2acd015c8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -1,27 +1,49 @@ +from functools import partial + +import pytest import torch +import torch.multiprocessing as mp import torch.nn as nn from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.testing import parameterize +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -@parameterize('bias', [True, False]) -def test_conv_module_handler(bias): - model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta')) - tracer = ColoTracer() +def check_conv_module_handler(rank, bias, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 + input = torch.rand(4, 4, 64, 64).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of conv node in computation graph + node_index = 1 + # total number of conv strategies + strategy_number = 16 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + tracer = ColoTracer() graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')}) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) conv_mod_node = list(graph.nodes)[1] strategies_vector = StrategiesVector(conv_mod_node) @@ -38,26 +60,26 @@ def test_conv_module_handler(bias): assert op_data.data is not None assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta + # assert mapping['input'].data.is_meta assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) assert mapping['input'].type == OperationDataType.ARG assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) assert mapping['other'].name == "weight" - assert mapping['other'].data.is_meta + # assert mapping['other'].data.is_meta assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) assert mapping['other'].type == OperationDataType.PARAM assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) if bias: assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta + # assert mapping['bias'].data.is_meta assert mapping['bias'].data.shape == torch.Size([16]) assert mapping['bias'].type == OperationDataType.PARAM assert mapping['bias'].logical_shape == torch.Size([16]) assert mapping['output'].name == "_0" - assert mapping['output'].data.is_meta + # assert mapping['output'].data.is_meta assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) assert mapping['output'].type == OperationDataType.OUTPUT @@ -129,9 +151,33 @@ class ConvModel(nn.Module): return x -@parameterize('bias', [True, False]) -def test_conv_function_handler(bias): - model = ConvModel() +def check_conv_function_handler(rank, bias, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvModel().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 64, 64).cuda() + others = torch.rand(16, 4, 3, 3).cuda() + input_args = [input, others] + meta_arg_names = ['input', 'others'] + input_kwargs = {} + # total number of conv strategies + strategy_number = 16 + node_index = 2 + if bias: + bias_tensor = torch.rand(16).cuda() + input_kwargs['bias'] = bias_tensor + node_index += 1 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs) + tracer = ColoTracer() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -143,10 +189,6 @@ def test_conv_function_handler(bias): meta_args['bias'] = torch.rand(16).to('meta') graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) if bias: conv_mod_node = list(graph.nodes)[3] @@ -248,6 +290,30 @@ def test_conv_function_handler(bias): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +# We temporarily ban the bias option before doing bias add +# before all reduce communication may encounter correctness issue. +# @parameterize('bias', [True, False]) +@rerun_if_address_is_in_use() +def test_conv_module_handler(bias=False): + world_size = 4 + run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +# We temporarily ban the bias option before doing bias add +# before all reduce communication may encounter correctness issue. +# @parameterize('bias', [True, False]) +@rerun_if_address_is_in_use() +def test_conv_function_handler(bias=False): + world_size = 4 + run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + if __name__ == '__main__': test_conv_module_handler() test_conv_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py new file mode 100644 index 000000000..5bce383dd --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py @@ -0,0 +1,286 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import ( + EmbeddingFunctionHandler, + EmbeddingModuleHandler, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + +NUM_EMBEDDINGS = 16 +EMBEDDING_DIMS = 32 + + +class EmbeddingModule(nn.Module): + + def __init__(self, num_embeddings, embedding_dims): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dims) + + def forward(self, input): + x = self.embedding(input) + return x + + +def check_embedding_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %embedding : [#users=1] = call_module[target=embedding](args = (%input_1,), kwargs = {}) + # return embedding + input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS + input = input.to(torch.int64).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of embedding node in computation graph + node_index = 1 + # total number of embedding strategies + strategy_number = 19 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=['input']) + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 16).to('meta')}) + gm = ColoGraphModule(model, graph) + embedding_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(embedding_node) + + # build handler + handler = EmbeddingModuleHandler(node=embedding_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + # assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 16, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([1024]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + + assert mapping['output'].name == "embedding" + assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # RR = RR x RR + assert 'RR = R x RR' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0 x RR_0' in strategy_name_list + assert 'S0R = S0 x RR_1' in strategy_name_list + assert 'S0R = S0 x RR_2' in strategy_name_list + assert 'S1R = S1 x RR_0' in strategy_name_list + assert 'S1R = S1 x RR_1' in strategy_name_list + assert 'S1R = S1 x RR_2' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0 x RS1_0' in strategy_name_list + assert 'S0S1 = S0 x RS1_1' in strategy_name_list + assert 'S0S1 = S0 x RS1_2' in strategy_name_list + assert 'S1S0 = S1 x RS0_0' in strategy_name_list + assert 'S1S0 = S1 x RS0_1' in strategy_name_list + assert 'S1S0 = S1 x RS0_2' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = R x RS0' in strategy_name_list + assert 'RS1 = R x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01 x RR_0' in strategy_name_list + assert 'S01R = S01 x RR_1' in strategy_name_list + assert 'S01R = S01 x RR_2' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = R x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1] + + +class EmbeddingFunction(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others): + x = nn.functional.embedding(input, others) + return x + + +def check_embedding_function_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = EmbeddingFunction().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS + input = input.to(torch.int64).cuda() + others = torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).cuda() + input_args = [input, others] + meta_arg_names = ['input', 'others'] + input_kwargs = {} + # total number of embedding strategies + strategy_number = 19 + node_index = 2 + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs) + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %others : torch.Tensor [#users=1] = placeholder[target=others] + # %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) + # return embedding + meta_args = { + "input": torch.rand(4, 16, 16).to('meta'), + "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta') + } + graph = tracer.trace(model, meta_args=meta_args) + gm = ColoGraphModule(model, graph) + + embedding_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(embedding_node) + + # build handler + handler = EmbeddingFunctionHandler(node=embedding_node, + device_mesh=device_mesh, + strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 16, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([1024]) + + assert mapping['other'].name == "others" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + + assert mapping['output'].name == "embedding" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # RR = RR x RR + assert 'RR = R x RR' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0 x RR_0' in strategy_name_list + assert 'S0R = S0 x RR_1' in strategy_name_list + assert 'S0R = S0 x RR_2' in strategy_name_list + assert 'S1R = S1 x RR_0' in strategy_name_list + assert 'S1R = S1 x RR_1' in strategy_name_list + assert 'S1R = S1 x RR_2' in strategy_name_list + + # SS = SR x RS + assert 'S0S1 = S0 x RS1_0' in strategy_name_list + assert 'S0S1 = S0 x RS1_1' in strategy_name_list + assert 'S0S1 = S0 x RS1_2' in strategy_name_list + assert 'S1S0 = S1 x RS0_0' in strategy_name_list + assert 'S1S0 = S1 x RS0_1' in strategy_name_list + assert 'S1S0 = S1 x RS0_2' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = R x RS0' in strategy_name_list + assert 'RS1 = R x RS1' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01 x RR_0' in strategy_name_list + assert 'S01R = S01 x RR_1' in strategy_name_list + assert 'S01R = S01 x RR_2' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = R x RS01' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('others') + output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + + # make sure the sharding matches across different operation data + assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] + assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1] + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_embedding_module_handler(): + world_size = 4 + run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_embedding_function_handler(): + world_size = 4 + run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_embedding_module_handler() + test_embedding_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py new file mode 100644 index 000000000..681e93a5f --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer + + +class GetattrModel(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(4, 16, 3, padding=1, bias=False) + + def forward(self, input): + weight = self.conv.weight + return weight + + +def test_getattr_handler(): + model = GetattrModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=0] = placeholder[target=input] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # return conv_weight + graph = tracer.trace(model, meta_args={'input': torch.rand(4, 4, 64, 64).to('meta')}) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + getattr_node = list(graph.nodes)[1] + getattr_strategies_vector = StrategiesVector(getattr_node) + + # build handler + getattr_handler = GetattrHandler(node=getattr_node, + device_mesh=device_mesh, + strategies_vector=getattr_strategies_vector) + + getattr_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = getattr_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['output'].name == "conv_weight" + assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3)) + assert mapping['output'].type == OperationDataType.OUTPUT + strategy_name_list = [val.name for val in getattr_handler.strategies_vector] + assert 'get_attr [S0, S1, R, R]' in strategy_name_list + assert 'get_attr [S1, S0, R, R]' in strategy_name_list + assert 'get_attr [S01, R, R, R]' in strategy_name_list + assert 'get_attr [R, S01, R, R]' in strategy_name_list + assert 'get_attr [S0, R, R, R]' in strategy_name_list + assert 'get_attr [R, S0, R, R]' in strategy_name_list + assert 'get_attr [S1, R, R, R]' in strategy_name_list + assert 'get_attr [R, S1, R, R]' in strategy_name_list + assert 'get_attr [R, R, R, R]' in strategy_name_list + + +if __name__ == '__main__': + test_getattr_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index 5f7c469bc..3c35da61b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -1,56 +1,83 @@ +from functools import partial + +import pytest import torch +import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler +from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx.tracer.meta_patch.patched_module import linear +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -class GetItemModel(nn.Module): +class GetItemFromTensorModel(nn.Module): - def __init__(self): + def __init__(self, getitem_index): super().__init__() + self.getitem_index = getitem_index def forward(self, input, other): - conv_node = nn.functional.conv2d(input, other) - x = conv_node[1] + linear_node = nn.functional.linear(input, other, bias=None) + x = linear_node[self.getitem_index] return x -def test_getitem_function_handler(): - model = GetItemModel() +def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + model = GetItemFromTensorModel(getitem_index=getitem_index) + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() - # graph(): - # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] - # %other : torch.Tensor [#users=1] = placeholder[target=other] - # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) - # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%conv2d, 1), kwargs = {}) - # return getitem + graph = tracer.trace(model, meta_args={ - "input": torch.rand(4, 4, 64, 64).to('meta'), - "other": torch.rand(4, 16, 3, 3).to('meta'), + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), }) - gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - conv_mod_node = list(graph.nodes)[2] + gm = ColoGraphModule(model, graph) + linear_mod_node = list(graph.nodes)[2] getitem_mod_node = list(graph.nodes)[3] getitem_strategies_vector = StrategiesVector(getitem_mod_node) - conv_strategies_vector = StrategiesVector(conv_mod_node) + linear_strategies_vector = StrategiesVector(linear_mod_node) # build handler - conv_handler = ConvFunctionHandler(node=conv_mod_node, - device_mesh=device_mesh, - strategies_vector=conv_strategies_vector) - conv_handler.register_strategy(compute_resharding_cost=False) - setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) + linear_handler = LinearFunctionHandler(node=linear_mod_node, + device_mesh=device_mesh, + strategies_vector=linear_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(linear_mod_node, 'strategies_vector', linear_strategies_vector) getitem_handler = GetItemHandler(node=getitem_mod_node, device_mesh=device_mesh, strategies_vector=getitem_strategies_vector) @@ -64,11 +91,88 @@ def test_getitem_function_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "conv2d" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) + # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(getitem_strategies_vector) == len(linear_strategies_vector) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))]) +@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) +def test_getitem_from_tensor_handler(getitem_index): + world_size = 4 + run_func = partial(check_getitem_from_tensor_handler, + getitem_index=getitem_index, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +class GetItemFromTupleModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input): + split_node = torch.split(input, 2, 0) + x = split_node[1] + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_getitem_from_tuple_handler(): + model = GetItemFromTupleModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0}) + # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {}) + # return getitem + graph = tracer.trace(model, meta_args={ + "input": torch.rand(4, 4, 64, 64).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + input_node = list(graph.nodes)[0] + split_node = list(graph.nodes)[1] + getitem_node = list(graph.nodes)[2] + input_strategies_vector = StrategiesVector(input_node) + getitem_strategies_vector = StrategiesVector(getitem_node) + split_strategies_vector = StrategiesVector(split_node) + + # build handler + input_handler = PlaceholderHandler( + node=input_node, + device_mesh=device_mesh, + strategies_vector=input_strategies_vector, + placeholder_option='replicated', + ) + input_handler.register_strategy(compute_resharding_cost=False) + setattr(input_node, 'strategies_vector', input_strategies_vector) + split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector) + split_handler.register_strategy(compute_resharding_cost=False) + setattr(split_node, 'strategies_vector', split_strategies_vector) + getitem_handler = GetItemHandler(node=getitem_node, + device_mesh=device_mesh, + strategies_vector=getitem_strategies_vector) + getitem_handler.register_strategy(compute_resharding_cost=False) + setattr(getitem_node, 'strategies_vector', getitem_strategies_vector) + + # check operation data mapping + mapping = getitem_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "split" assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) + assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64])) assert mapping['index'].name == "index" assert isinstance(mapping['index'].data, int) @@ -76,12 +180,13 @@ def test_getitem_function_handler(): assert mapping['output'].name == "getitem" assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 62, 62]) + assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64]) assert mapping['output'].type == OperationDataType.OUTPUT # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. - assert len(getitem_strategies_vector) == len(conv_strategies_vector) + assert len(getitem_strategies_vector) == len(split_strategies_vector) if __name__ == '__main__': - test_getitem_function_handler() + test_getitem_from_tensor_handler() + test_getitem_from_tuple_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py index 1a8487e7e..f4d0063fd 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -1,16 +1,45 @@ +from functools import partial + +import pytest import torch +import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import \ - LayerNormModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx.tracer.meta_patch.patched_module import linear +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -def test_ln_module_handler(): - model = nn.Sequential(nn.LayerNorm(16).to('meta')) +def check_ln_module_handler(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.LayerNorm(16)).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 16).cuda() + # the index of bn node in computation graph + node_index = 1 + # the total number of ln strategies + strategy_number = 4 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['input'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) tracer = ColoTracer() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -18,10 +47,7 @@ def test_ln_module_handler(): # return _0 graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')}) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) ln_mod_node = list(graph.nodes)[1] strategies_vector = StrategiesVector(ln_mod_node) @@ -38,25 +64,21 @@ def test_ln_module_handler(): assert op_data.data is not None assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta assert mapping['input'].data.shape == torch.Size([4, 16]) assert mapping['input'].type == OperationDataType.ARG assert mapping['input'].logical_shape == torch.Size([4, 16]) assert mapping['other'].name == "weight" - assert mapping['other'].data.is_meta assert mapping['other'].data.shape == torch.Size([16]) assert mapping['other'].type == OperationDataType.PARAM assert mapping['other'].logical_shape == torch.Size([16]) assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta assert mapping['bias'].data.shape == torch.Size([16]) assert mapping['bias'].type == OperationDataType.PARAM assert mapping['bias'].logical_shape == torch.Size([16]) assert mapping['output'].name == "_0" - assert mapping['output'].data.is_meta assert mapping['output'].data.shape == torch.Size([4, 16]) assert mapping['output'].type == OperationDataType.OUTPUT @@ -74,5 +96,14 @@ def test_ln_module_handler(): assert '[S01, R] = [S01, R] x [R]' in strategy_name_list +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_ln_module_handler(): + world_size = 4 + run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + if __name__ == '__main__': test_ln_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index 290d73f5a..3d268ea43 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -1,4 +1,10 @@ +from faulthandler import disable +from functools import partial +from xml.dom import WrongDocumentErr + +import pytest import torch +import torch.multiprocessing as mp import torch.nn as nn from typing_extensions import Self @@ -11,21 +17,45 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ) from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.utils import parameterize +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -@parameterize('bias', [True, False]) -def test_linear_module_handler(bias): - model = nn.Sequential(nn.Linear(16, 32, bias=bias).to('meta')) +def check_linear_module_handler(rank, bias, input_shape, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(input_shape).cuda() + # the index of linear node in computation graph + node_index = 1 + # strategy number of linear node + if input_shape == (1, 4, 4, 16): + strategy_number = 19 + else: + strategy_number = 24 + # construct input args + input_args = [input] + # construct meta arg names + meta_arg_names = ['input'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) tracer = ColoTracer() - graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')}) + graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')}) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - print(graph) - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) linear_mod_node = list(graph.nodes)[1] strategies_vector = StrategiesVector(linear_mod_node) @@ -42,42 +72,51 @@ def test_linear_module_handler(bias): assert op_data.data is not None assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16]) + assert mapping['input'].data.shape == torch.Size(input_shape) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([16, 16]) + input_logical_shape = mapping['input'].data.view(-1, 16).shape + assert mapping['input'].logical_shape == input_logical_shape assert mapping['other'].name == "weight" - assert mapping['other'].data.is_meta assert mapping['other'].data.shape == torch.Size([32, 16]) assert mapping['other'].type == OperationDataType.PARAM assert mapping['other'].logical_shape == torch.Size([16, 32]) if bias: assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta assert mapping['bias'].data.shape == torch.Size([32]) assert mapping['bias'].type == OperationDataType.PARAM assert mapping['bias'].logical_shape == torch.Size([32]) assert mapping['output'].name == "_0" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32]) + output_shape = input_shape[:-1] + (32,) + assert mapping['output'].data.shape == torch.Size(output_shape) assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([16, 32]) + output_logical_shape = mapping['output'].data.view(-1, 32).shape + assert mapping['output'].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] - # one strategy will be converted to different physical sharding spec - assert len(strategy_name_list) > 8 + + # First dimension cannot be shard if input shape is (1, 4, 4, 16) + if input_shape != (1, 4, 4, 16): + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S01R = S01R x RR_0' in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list # RS = RS x SS assert 'RS0 = RS1 x S1S0' in strategy_name_list @@ -91,6 +130,19 @@ def test_linear_module_handler(bias): assert 'RS0 = RR x RS0' in strategy_name_list assert 'RS1 = RR x RS1' in strategy_name_list + # S01R = S01R x RR + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + for strategy in strategies_vector: strategy: ShardingStrategy input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') @@ -109,18 +161,51 @@ def test_linear_module_handler(bias): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@parameterize('bias', [True, False]) -def test_linear_function_handler(bias): - model = nn.Linear(16, 32, bias=bias).to('meta') - tracer = ColoTracer() - graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')}) - gm = ColoGraphModule(model, graph) +class LinearModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others, bias=None): + x = nn.functional.linear(input, others, bias=bias) + return x + + +def check_linear_function_handler(rank, bias, input_shape, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearModel().cuda() physical_mesh_id = torch.arange(0, 4) - - print(graph) mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(input_shape).cuda() + other = torch.rand(32, 16).cuda() + # the index of linear node in computation graph + node_index = 2 + # strategy number of linear node + if input_shape == (1, 4, 4, 16): + strategy_number = 19 + else: + strategy_number = 24 + # construct input args + input_args = [input, other] + # construct meta arg names + meta_arg_names = ['input', 'others'] + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names) + + tracer = ColoTracer() + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(input_shape).to('meta'), + 'others': torch.rand(32, 16).to('meta') + }) + gm = ColoGraphModule(model, graph) if bias: linear_func_node = list(graph.nodes)[3] else: @@ -134,41 +219,51 @@ def test_linear_function_handler(bias): mapping = handler.get_operation_data_mapping() assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16]) + assert mapping['input'].data.shape == torch.Size(input_shape) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([16, 16]) + input_logical_shape = mapping['input'].data.view(-1, 16).shape + assert mapping['input'].logical_shape == torch.Size(input_logical_shape) - assert mapping['other'].name == "weight" - assert mapping['other'].data.is_meta + assert mapping['other'].name == "others" assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].type == OperationDataType.ARG assert mapping['other'].logical_shape == torch.Size([16, 32]) if bias: assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta assert mapping['bias'].data.shape == torch.Size([32]) - assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].type == OperationDataType.ARG assert mapping['other'].logical_shape == torch.Size([16, 32]) assert mapping['output'].name == "linear" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32]) + output_shape = input_shape[:-1] + (32,) + assert mapping['output'].data.shape == torch.Size(output_shape) assert mapping['output'].type == OperationDataType.OUTPUT + output_logical_shape = mapping['output'].data.view(-1, 32).shape + assert mapping['output'].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] - # one strategy will be converted to different physical sharding spec - assert len(strategy_name_list) > 8 + + # First dimension cannot be shard if input shape is (1, 4, 4, 16) + if input_shape != (1, 4, 4, 16): + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S01R = S01R x RR_0' in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list # RS = RS x SS assert 'RS0 = RS1 x S1S0' in strategy_name_list @@ -182,10 +277,23 @@ def test_linear_function_handler(bias): assert 'RS0 = RR x RS0' in strategy_name_list assert 'RS1 = RR x RS1' in strategy_name_list + # S01R = S01R x RR + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + for strategy in strategies_vector: strategy: ShardingStrategy input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + weight_sharding_spec = strategy.get_sharding_spec_by_name('others') output_sharding_spec = strategy.get_sharding_spec_by_name('linear') if bias: @@ -200,6 +308,25 @@ def test_linear_function_handler(bias): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] +@run_on_environment_flag(name='AUTO_PARALLEL') +@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_handler(input_shape, bias=False): + world_size = 4 + run_func_module = partial(check_linear_module_handler, + bias=bias, + input_shape=input_shape, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + run_func_function = partial(check_linear_function_handler, + bias=bias, + input_shape=input_shape, + world_size=world_size, + port=free_port()) + mp.spawn(run_func_function, nprocs=world_size) + + if __name__ == '__main__': - test_linear_module_handler() - test_linear_function_handler() + test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py new file mode 100644 index 000000000..306c45f56 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import ( + MatMulHandler, + MatMulType, + _get_bmm_logical_shape, + get_matmul_type, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.utils import parameterize + + +class MatMulModule(nn.Module): + + def forward(self, x1, x2): + return torch.matmul(x1, x2) + + +@parameterize( + 'tensor_shapes', + [ + [[8], [8]], # dot product + [[4, 8], [8]], # mat-vec product + [[4, 8], [8, 16]], # mat-mat product + [[8], [8, 16]], # mat-mat product + [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting + [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting + ]) +def test_matmul_node_handler(tensor_shapes): + input_shape, other_shape = tensor_shapes + + # get output shape + x1 = torch.rand(*input_shape) + x2 = torch.rand(*other_shape) + output_shape = list(torch.matmul(x1, x2).shape) + + # get matmul type + matmul_type = get_matmul_type(x1.dim(), x2.dim()) + + model = MatMulModule() + + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')}) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + print(graph) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + mod_node = list(graph.nodes)[2] + strategies_vector = StrategiesVector(mod_node) + + # build handler + handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + logical_input_shape = input_shape + logical_other_shape = other_shape + logical_output_shape = output_shape + if matmul_type == MatMulType.MM and len(input_shape) == 1: + logical_input_shape = [1] + input_shape + elif matmul_type == MatMulType.BMM: + logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape( + input_shape, other_shape, handler.transforms) + else: + logical_input_shape = input_shape + + # check input operation data + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size(input_shape) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size(logical_input_shape) + + # check other operation data + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size(other_shape) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size(logical_other_shape) + + # check output + assert mapping['output'].name == "matmul" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size(output_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size(logical_output_shape) + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # ensure there is no duplicate strategy + if matmul_type != MatMulType.BMM: + assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list + + for strategy in strategies_vector: + strategy: ShardingStrategy + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('matmul') + + if matmul_type == MatMulType.DOT: + # dot product will produce a scaler + # results should fulfill: + # 1. the input and other operands have the same sharding spec + # 2. the output has no sharding + assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence + assert len(output_sharding_spec.sharding_sequence) == 0 + elif matmul_type == MatMulType.MV: + # matrix-vector product should fulfill + # 1. the last dim of the input and other operands should have the same sharding + # 2. the first dim of the input and other should have the same sharding + # 3. the output should have only 1 dim + assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert len(output_sharding_spec.sharding_sequence) == 1 + elif matmul_type == MatMulType.MM: + # matrix-matrix multiplication should fulfil + # 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding + # 2. the input's last dim and the first dim of the other should have the same sharding + # 3. the last dim of the output and other should have the same sharding + # 4. the input and output should have the same number of dims + if len(input_shape) == 2: + assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] + assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0] + assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] + assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence) + elif matmul_type == MatMulType.BMM: + # bmm should fulfil + # 1. of the other tensor is not a 1d tensor, the last dim of other and output have the same sharding + # 2. if the input has more than 2 dim, the second last dim of input and output have the same sharding + # 3. if the other have more than 2 dim, the second last dim of other and the last dim of input should have the same sharding + if len(other_shape) > 1: + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + if len(input_shape) > 1: + assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2] + if len(other_shape) > 2: + assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1] + + +if __name__ == '__main__': + test_matmul_node_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py index d47876af2..f219bc2f3 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -2,15 +2,15 @@ import pytest import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \ - NormPoolingHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.testing.pytest_wrapper import run_on_environment_flag +@run_on_environment_flag(name='AUTO_PARALLEL') def test_norm_pool_handler(): model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py index 27b0af4fb..26376c429 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -1,11 +1,11 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \ - OuputHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use class OutputModel(nn.Module): @@ -18,7 +18,9 @@ class OutputModel(nn.Module): return x, y -def test_output_handler(): +@parameterize('output_option', ['distributed', 'replicated']) +@rerun_if_address_is_in_use() +def test_output_handler(output_option): model = OutputModel() tracer = ColoTracer() # graph(): @@ -37,7 +39,10 @@ def test_output_handler(): output_strategies_vector = StrategiesVector(output_node) # build handler - otuput_handler = OuputHandler(node=output_node, device_mesh=device_mesh, strategies_vector=output_strategies_vector) + otuput_handler = OutputHandler(node=output_node, + device_mesh=device_mesh, + strategies_vector=output_strategies_vector, + output_option=output_option) otuput_handler.register_strategy(compute_resharding_cost=False) # check operation data mapping @@ -49,10 +54,12 @@ def test_output_handler(): assert op_data.data is not None assert mapping['output'].name == "output" - assert mapping['output'].data.is_meta assert mapping['output'].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in otuput_handler.strategies_vector] - assert "Replica Output" in strategy_name_list + if output_option == 'distributed': + assert "Distributed Output" in strategy_name_list + else: + assert "Replica Output" in strategy_name_list if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py new file mode 100644 index 000000000..c695b8843 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -0,0 +1,339 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvReshapeModel(nn.Module): + + def __init__(self, reshape_dims, call_function): + super().__init__() + self.reshape_dims = reshape_dims + self.call_function = call_function + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + # permute_node = torch.permute(conv_node, self.permute_dims) + if self.call_function == torch.permute: + permute_node = self.call_function(conv_node, self.reshape_dims) + else: + permute_node = self.call_function(conv_node, *self.reshape_dims) + return permute_node + + +class LinearReshapeModel(nn.Module): + + def __init__(self, reshape_dims, call_function): + super().__init__() + self.reshape_dims = reshape_dims + self.call_function = call_function + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + # permute_node = torch.permute(linear_node, self.tgt_shape) + if self.call_function == torch.permute: + permute_node = self.call_function(linear_node, self.reshape_dims) + else: + permute_node = self.call_function(linear_node, *self.reshape_dims) + return permute_node + + +def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if call_function == torch.permute: + reshape_dims = reshape_dims[0] + elif call_function == torch.transpose: + reshape_dims = reshape_dims[1] + model = model_cls(reshape_dims, call_function).cuda() + + if model_cls.__name__ == 'ConvReshapeModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearReshapeModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + if model_cls.__name__ == 'ConvReshapeModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None}) + # %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {}) + # return permute + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 8, 66, 66).to('meta'), + "other": torch.rand(16, 8, 3, 3).to('meta'), + }) + + if model_cls.__name__ == 'LinearReshapeModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) + # return permute + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + reshape_node = list(graph.nodes)[3] + view_strategies_vector = StrategiesVector(reshape_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvReshapeModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearReshapeModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if call_function == torch.permute: + reshape_handler = PermuteHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=view_strategies_vector) + else: + reshape_handler = TransposeHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=view_strategies_vector) + + reshape_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = reshape_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvReshapeModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + if call_function == torch.permute: + assert mapping['output'].name == "permute" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape + assert mapping['output'].type == OperationDataType.OUTPUT + else: + assert mapping['output'].name == "transpose" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(view_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in view_strategies_vector] + if rank == 0: + for name in strategy_name_list: + print(name) + if model_cls.__name__ == 'ConvReshapeModel': + + if reshape_dims in ((0, 2, 1, 3), (1, 2)): + assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + + if reshape_dims == (2, 0, 1, 3): + assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + + if reshape_dims == (1, 3): + assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list + + if model_cls.__name__ == 'LinearReshapeModel': + + if reshape_dims == ((0, 2, 1, 3), (1, 2)): + assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + if reshape_dims == (2, 0, 1, 3): + assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + if reshape_dims == (1, 3): + assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('call_function', [torch.permute, torch.transpose]) +@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) +@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) +def test_view_handler(call_function, reshape_dims, model_cls): + world_size = 4 + run_func = partial(check_view_handler, + call_function=call_function, + reshape_dims=reshape_dims, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py index bdec901e9..9bc453a27 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -1,11 +1,11 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \ - PlacehodlerHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use class PlaceholderModel(nn.Module): @@ -17,7 +17,9 @@ class PlaceholderModel(nn.Module): return input -def test_placeholder_handler(): +@parameterize('placeholder_option', ['distributed', 'replicated']) +@rerun_if_address_is_in_use() +def test_placeholder_handler(placeholder_option): model = PlaceholderModel() tracer = ColoTracer() # graph(): @@ -33,16 +35,25 @@ def test_placeholder_handler(): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) placeholder_node = list(graph.nodes)[0] placeholder_strategies_vector = StrategiesVector(placeholder_node) - # build handler - placeholder_handler = PlacehodlerHandler(node=placeholder_node, + placeholder_handler = PlaceholderHandler(node=placeholder_node, device_mesh=device_mesh, - strategies_vector=placeholder_strategies_vector) + strategies_vector=placeholder_strategies_vector, + placeholder_option=placeholder_option) placeholder_handler.register_strategy(compute_resharding_cost=False) + # check operation data mapping mapping = placeholder_handler.get_operation_data_mapping() + strategy = placeholder_strategies_vector[0] + strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name) + + if placeholder_option == 'distributed': + assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]' + else: + assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]' + for name, op_data in mapping.items(): op_data: OperationData # make sure they have valid values @@ -53,7 +64,10 @@ def test_placeholder_handler(): assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64)) assert mapping['output'].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in placeholder_handler.strategies_vector] - assert "Replica Placeholder" in strategy_name_list + if placeholder_option == 'replicated': + assert "Replica Placeholder" in strategy_name_list + else: + assert "Distributed Placeholder" in strategy_name_list if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py index 613f8f3d0..de277002b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py @@ -20,6 +20,7 @@ class ReshapeModel(nn.Module): return reshape_node +@run_on_environment_flag(name='AUTO_PARALLEL') def test_reshape_handler(): model = ReshapeModel() tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py new file mode 100644 index 000000000..b5e8e3277 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -0,0 +1,186 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearSplitModel(nn.Module): + + def __init__(self, softmax_dim): + super().__init__() + self.softmax_dim = softmax_dim + + def forward(self, input, other): + linear_node = F.linear(input, other, bias=None) + softmax_node = F.softmax(linear_node, self.softmax_dim) + return softmax_node + + +def check_split_handler(rank, softmax_dim, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(softmax_dim=softmax_dim).cuda() + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) + # return split + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + split_node = list(graph.nodes)[3] + split_strategies_vector = StrategiesVector(split_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + softmax_handler = SoftmaxHandler(node=split_node, + device_mesh=device_mesh, + strategies_vector=split_strategies_vector) + + softmax_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = softmax_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['softmax_dim'].name == "softmax_dim" + assert mapping['softmax_dim'].data == softmax_dim + assert mapping['softmax_dim'].type == OperationDataType.ARG + + assert mapping['output'].name == "softmax" + assert mapping['output'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['output'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(split_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in split_strategies_vector] + + if softmax_dim == 0: + assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + if softmax_dim == 1: + assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('softmax_dim', [0, 1, 2, 3]) +@parameterize('model_cls', [LinearSplitModel]) +def test_split_handler(softmax_dim, model_cls): + world_size = 4 + run_func = partial(check_split_handler, + softmax_dim=softmax_dim, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py new file mode 100644 index 000000000..9e8e905c5 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -0,0 +1,270 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.experimental import SplitHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvSplitModel(nn.Module): + + def __init__(self, split_size, split_dim): + super().__init__() + self.split_size = split_size + self.split_dim = split_dim + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + split_node = conv_node.split(self.split_size, dim=self.split_dim) + return split_node + + +class LinearSplitModel(nn.Module): + + def __init__(self, split_size, split_dim): + super().__init__() + self.split_size = split_size + self.split_dim = split_dim + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + split_node = linear_node.split(self.split_size, dim=self.split_dim) + return split_node + + +def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(split_size=split_size, split_dim=split_dim).cuda() + + if model_cls.__name__ == 'ConvSplitModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearSplitModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + if model_cls.__name__ == 'ConvSplitModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {}) + # return split + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 8, 66, 66).to('meta'), + "other": torch.rand(16, 8, 3, 3).to('meta'), + }) + + if model_cls.__name__ == 'LinearSplitModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) + # return split + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + split_node = list(graph.nodes)[3] + split_strategies_vector = StrategiesVector(split_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvSplitModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearSplitModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + split_handler = SplitHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector) + + split_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = split_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvSplitModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "split" + split_items = torch.empty([8, 16, 64, 64]).split(split_size, split_dim) + assert mapping['output'].logical_shape == tuple([item.shape for item in split_items]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(split_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in split_strategies_vector] + for name in strategy_name_list: + print(name) + if model_cls.__name__ == 'ConvSplitModel': + + if split_dim == 0: + assert '[R, S1, R, R]_0' in strategy_name_list + assert '[R, S0, R, R]_1' in strategy_name_list + assert '[R, R, R, R]_2' in strategy_name_list + assert '[R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, R]_4' in strategy_name_list + assert '[R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R]_6' in strategy_name_list + assert '[R, S0, R, R]_7' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R]_10' in strategy_name_list + assert '[R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, R]_12' in strategy_name_list + assert '[R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R]_15' in strategy_name_list + + if split_dim == 1: + assert '[S0, R, R, R]_0' in strategy_name_list + assert '[S1, R, R, R]_1' in strategy_name_list + assert '[S0, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R]_5' in strategy_name_list + assert '[R, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R]_10' in strategy_name_list + assert '[R, R, R, R]_11' in strategy_name_list + assert '[R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R]_15' in strategy_name_list + + if model_cls.__name__ == 'LinearSplitModel': + + if split_dim == 0: + assert '[R, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1]_2' in strategy_name_list + assert '[R, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0]_5' in strategy_name_list + assert '[R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R]_7' in strategy_name_list + assert '[R, R, S0, R]_8' in strategy_name_list + assert '[R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R]_10' in strategy_name_list + assert '[R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1]_17' in strategy_name_list + assert '[R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R]_19' in strategy_name_list + assert '[R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01]_22' in strategy_name_list + + if split_dim == 1: + assert '[S0, R, R, S1]_0' in strategy_name_list + assert '[R, R, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1]_2' in strategy_name_list + assert '[S1, R, R, S0]_3' in strategy_name_list + assert '[R, R, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0]_5' in strategy_name_list + assert '[S0, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R]_18' in strategy_name_list + assert '[R, R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('split_size', [2]) +@parameterize('split_dim', [0, 1, 2]) +@parameterize('model_cls', [ConvSplitModel, LinearSplitModel]) +def test_split_handler(split_size, split_dim, model_cls): + world_size = 4 + run_func = partial(check_split_handler, + split_size=split_size, + split_dim=split_dim, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py new file mode 100644 index 000000000..5fda4de1a --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -0,0 +1,235 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearSumModel(nn.Module): + + def __init__(self, sum_dims, keepdim): + super().__init__() + self.sum_dims = sum_dims + self.keepdim = keepdim + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + if self.sum_dims is not None: + sum_node = torch.sum(linear_node, self.sum_dims, keepdim=self.keepdim) + else: + sum_node = torch.sum(linear_node, keepdim=self.keepdim) + return sum_node + + +def check_sum_handler(rank, sum_dims, keepdim, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 24 + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + + tracer = ColoTracer() + + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {}) + # return sum_1 + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + sum_node = list(graph.nodes)[3] + sum_strategies_vector = StrategiesVector(sum_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + sum_handler = SumHandler(node=sum_node, device_mesh=device_mesh, strategies_vector=sum_strategies_vector) + + sum_handler.register_strategy(compute_resharding_cost=False) + + # sum handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(sum_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in sum_strategies_vector] + + # check operation data mapping + mapping = sum_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "sum_1" + sum_node_shape = torch.empty([8, 16, 64, 64]).sum(sum_dims, keepdim=keepdim).shape + assert mapping['output'].logical_shape == sum_node_shape + assert mapping['output'].type == OperationDataType.OUTPUT + + # check strategy name + if sum_dims == (0, 2) and keepdim == False: + assert '[R, R, R, S1] -> [R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [S0, S1]_1' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_2' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [S1, S0]_4' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [S0, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [S1, R]_10' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [S01, R]_19' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, S01]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list + + if sum_dims == (0, 2) and keepdim == True: + assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_2' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list + + if sum_dims == 1 and keepdim == False: + assert '[S0, R, R, S1] -> [S0, R, S1]_0' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S0, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, S0]_3' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S1, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R]_18' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, S01]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list + + if sum_dims == 1 and keepdim == True: + assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('sum_dims', [(0, 2), 1]) +@parameterize('keepdim', [False, True]) +def test_sum_handler(sum_dims, keepdim): + world_size = 4 + run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_sum_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py new file mode 100644 index 000000000..de35fe256 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class TensorConstructorModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + arange_node = torch.arange(x.size()[0]) + x = x + arange_node + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_where_handler(): + model = TensorConstructorModel() + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=2] = placeholder[target=x] + # %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {}) + # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%size, 0), kwargs = {}) + # %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {}) + # return add + graph = tracer.trace(model, meta_args={ + "x": torch.rand(10).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + arange_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(arange_node) + + # build handler + handler = TensorConstructorHandler(node=arange_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['output'].name == "arange" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([10]) + assert mapping['output'].type == OperationDataType.OUTPUT + + handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + assert 'Replica Tensor Constructor' in strategy_name_list + + +if __name__ == '__main__': + test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py index e4d12cd12..a861cb7f5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -22,6 +22,7 @@ class ReLuModel(nn.Module): return relu_node +@run_on_environment_flag(name='AUTO_PARALLEL') def test_elementwise_handler(): model = ReLuModel() tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py new file mode 100644 index 000000000..08a702789 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -0,0 +1,265 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class ConvViewModel(nn.Module): + + def __init__(self, tgt_shape): + super().__init__() + self.tgt_shape = tgt_shape + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other, bias=None) + reshape_node = conv_node.view(*self.tgt_shape) + return reshape_node + + +class LinearViewModel(nn.Module): + + def __init__(self, tgt_shape): + super().__init__() + self.tgt_shape = tgt_shape + + def forward(self, input, other): + linear_node = nn.functional.linear(input, other, bias=None) + reshape_node = linear_node.view(*self.tgt_shape) + return reshape_node + + +def check_view_handler(rank, tgt_shape, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(tgt_shape).cuda() + + if model_cls.__name__ == 'ConvViewModel': + input = torch.rand(8, 8, 66, 66).to('cuda') + other = torch.rand(16, 8, 3, 3).to('cuda') + # index of conv node in computation graph + node_index = 2 + # total number of conv strategies + strategy_number = 16 + if model_cls.__name__ == 'LinearViewModel': + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + if model_cls.__name__ == 'ConvViewModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) + # return view + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 8, 66, 66).to('meta'), + "other": torch.rand(16, 8, 3, 3).to('meta'), + }) + + if model_cls.__name__ == 'LinearViewModel': + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) + # return view + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + view_node = list(graph.nodes)[3] + view_strategies_vector = StrategiesVector(view_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + if model_cls.__name__ == 'ConvViewModel': + + conv_handler = ConvFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + if model_cls.__name__ == 'LinearViewModel': + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector) + + view_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = view_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + if model_cls.__name__ == 'ConvViewModel': + assert mapping['input'].name == "conv2d" + else: + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "view" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size(tgt_shape) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(view_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in view_strategies_vector] + + if model_cls.__name__ == 'ConvViewModel': + + if tgt_shape == (32, 4, 64, 16, 4): + assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list + assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list + + if tgt_shape == (8, 4, 4, 64, 16, 4): + assert '[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0' in strategy_name_list + assert '[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_10' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_15' in strategy_name_list + + if model_cls.__name__ == 'LinearViewModel': + + if tgt_shape == (32, 4, 64, 16, 4): + assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_0' in strategy_name_list + assert '[R, S0, R, S1] -> FULLY REPLICATED_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_3' in strategy_name_list + assert '[R, S1, R, S0] -> FULLY REPLICATED_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> FULLY REPLICATED_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01, R]_22' in strategy_name_list + + if tgt_shape == (8, 4, 4, 64, 16, 4): + assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) +@parameterize('model_cls', [ConvViewModel, LinearViewModel]) +def test_view_handler(tgt_shape, model_cls): + world_size = 4 + run_func = partial(check_view_handler, + tgt_shape=tgt_shape, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py new file mode 100644 index 000000000..d02e1e31e --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -0,0 +1,188 @@ +import copy +from typing import Dict, List + +import torch +from torch.fx import GraphModule + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.solver.solver import Solver +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.shape_consistency import to_global +from colossalai.testing.comparison import assert_close + + +def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor], + input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[any, torch.Tensor]): + + model_to_compare = copy.deepcopy(model) + args_to_compare = [] + kwargs_to_compare = {} + for arg_index, input_tensor in enumerate(input_args): + + def wrapper(param, index): + + def hook_fn(grad): + grad_dict[index] = grad + + param.register_hook(hook_fn) + + arg_to_compare = copy.deepcopy(input_tensor) + + # only Tensors of floating point and complex dtype can require gradients + if arg_to_compare.dtype != torch.int64: + arg_to_compare.requires_grad = True + wrapper(arg_to_compare, arg_index) + + args_to_compare.append(arg_to_compare) + + for name, input_kwarg in input_kwargs.items(): + + def wrapper(param, name): + + def hook_fn(grad): + grad_dict[name] = grad + + param.register_hook(hook_fn) + + kwarg_to_compare = copy.deepcopy(input_kwarg) + + # only Tensors of floating point and complex dtype can require gradients + if kwarg_to_compare.dtype != torch.int64: + kwarg_to_compare.requires_grad = True + wrapper(kwarg_to_compare, name) + + kwargs_to_compare[name] = kwarg_to_compare + + return model_to_compare, args_to_compare, kwargs_to_compare + + +def numerical_test_for_node_strategy(model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}, + node_type: str = 'normal'): + for strategy_index in range(strategy_number): + print(f'#strategy_index: {strategy_index}') + # We need to copy the model to avoid do backward more than once in same graph + grad_to_compare_dict = {} + grad_to_shard_dict = {} + model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare( + model, input_args, input_kwargs, grad_to_compare_dict) + model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs, + grad_to_shard_dict) + + tracer = ColoTracer() + input_sample = {} + for input_arg, meta_arg_name in zip(input_args, meta_arg_names): + input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') + for meta_kwarg_name, input_kwarg in input_kwargs.items(): + input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') + graph = tracer.trace(root=model_to_shard, meta_args=input_sample) + gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + target_node = list(graph.nodes)[node_index] + if node_type == 'normal': + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + elif node_type == 'following': + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + solution[node_index + 1] = strategy_index + else: + node_vector = strategies_constructor.leaf_strategies[node_index] + strategy_to_keep = node_vector[strategy_index] + node_vector = [strategy_to_keep] + # solution construction + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False) + ret = solver.call_solver_serialized_args() + solution = list(ret[0]) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh) + gm = runtime_apply_pass(gm) + gm.recompile() + + # forward result compare + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare) + assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type='forward output') + + # backward result compare + if isinstance(output, (tuple, list)): + loss = output[0].sum() + loss_to_compare = output_to_compare[0].sum() + else: + loss = output.sum() + loss_to_compare = output_to_compare.sum() + + loss_to_compare.backward() + loss.backward() + for key in grad_to_shard_dict.keys(): + grad_to_shard = grad_to_shard_dict[key] + grad_to_compare = grad_to_compare_dict[key] + assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad') + # extract the strategy used in this iter + strategy_in_use = target_node.strategies_vector[strategy_index] + param_to_shard_dict = dict(gm.named_parameters()) + param_to_compare_dict = dict(model_to_compare.named_parameters()) + for name in param_to_shard_dict.keys(): + param_name = name.split('.')[-1] + if node_type == 'normal': + param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name) + else: + if 'weight' in name: + param_sharding_spec = None + + for node in list(graph.nodes): + if 'weight' in node.name: + param_sharding_spec = node.sharding_spec + + elif 'bias' in name: + param_sharding_spec = None + + for node in list(graph.nodes): + if 'bias' in node.name: + param_sharding_spec = node.sharding_spec + + assert param_sharding_spec is not None + grad_sharded = param_to_shard_dict[name].grad + grad_to_compare = param_to_compare_dict[name].grad + global_grad = to_global(grad_sharded, param_sharding_spec) + assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type='param grad') + + +def assert_close_helper(first: torch.Tensor, + second: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + strategy_index: int = -1, + type: str = 'not defined'): + """ + This method is used to check whether the average difference between two tensors is as close as expected. + """ + try: + if isinstance(first, (tuple, list)): + for first_element, second_element in zip(first, second): + assert_close(first_element, second_element, rtol=rtol, atol=atol) + else: + assert_close(first, second, rtol=rtol, atol=atol) + except: + print(f'strategy index {strategy_index} encounter assert_close error on {type}') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py new file mode 100644 index 000000000..b504d59c9 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py @@ -0,0 +1,131 @@ +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +def _param_resharding_cost_assertion(node): + for strategy in node.strategies_vector: + for prev_node, resharding_cost in strategy.resharding_costs.items(): + if strategy.get_op_data_by_name(str(prev_node)).type == OperationDataType.PARAM: + for cost in resharding_cost: + assert cost.fwd == 0 + assert cost.bwd == 0 + assert cost.total == 0 + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.linear(x) + x = x * 2 + + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias) + + def forward(self, x): + x = self.conv(x) + x = x * 2 + + return x + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_linear_module(): + model = LinearModel(4, 8) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %linear_weight : [#users=1] = get_attr[target=linear.weight] + # %linear_bias : [#users=1] = get_attr[target=linear.bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')}) + # def forward(self, x : torch.Tensor): + # linear_weight = self.linear.weight + # linear_bias = self.linear.bias + # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None + # add = linear + linear_bias; linear = linear_bias = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + gm.recompile() + node_list = list(graph.nodes) + + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + linear_node = node_list[3] + _param_resharding_cost_assertion(linear_node) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +def test_conv_module(): + model = ConvModel(3, 6, 2) + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + # def forward(self, x : torch.Tensor): + # conv_weight = self.conv.weight + # conv_bias = self.conv.bias + # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None + # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None + # add = conv2d + view; conv2d = view = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + + gm.recompile() + node_list = list(graph.nodes) + conv_node = node_list[3] + solver_options = SolverOptions() + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + _param_resharding_cost_assertion(conv_node) + + +if __name__ == '__main__': + test_linear_module() + test_conv_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py b/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py index 1f753522c..814edd279 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py @@ -10,6 +10,8 @@ from torch.fx import GraphModule from torchvision.models import resnet34, resnet50 from colossalai import device +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.tensor_shard.constants import * from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser @@ -17,10 +19,6 @@ from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver.solver import Solver from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import ( - shape_consistency_pass, - solution_annotatation_pass, -) from colossalai.fx.tracer.tracer import ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -140,7 +138,7 @@ def check_apply_bottleneck(rank, world_size, port): graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() @@ -153,8 +151,8 @@ def check_apply_bottleneck(rank, world_size, port): print(solution) for index, node in enumerate(graph.nodes): print(node.name, node.strategies_vector[solution[index]].name) - sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh) - shape_consistency_pass(gm) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + gm = runtime_apply_pass(gm) gm.recompile() nodes = [node for node in gm.graph.nodes] # TODO: wrap the gm to avoid the influence of the user training code @@ -164,7 +162,7 @@ def check_apply_bottleneck(rank, world_size, port): output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict) assert output.shape == origin_output.shape - assert_close(output, origin_output) + assert_close(output, origin_output, rtol=1e-03, atol=1e-05) print("*******************backward starting*******************") cuda_rng_state = torch.cuda.get_rng_state() output.sum().backward() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py index 7dd0ae842..66cd3f3f7 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py @@ -7,6 +7,8 @@ import torch.multiprocessing as mp import torch.nn as nn from torch.fx import GraphModule +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.tensor_shard.solver import ( CostGraph, GraphAnalyser, @@ -15,10 +17,6 @@ from colossalai.auto_parallel.tensor_shard.solver import ( StrategiesConstructor, ) from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import ( - shape_consistency_pass, - solution_annotatation_pass, -) from colossalai.fx.tracer.tracer import ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -62,7 +60,7 @@ def check_apply(rank, world_size, port): graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() @@ -72,8 +70,8 @@ def check_apply(rank, world_size, port): solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) ret = solver.call_solver_serialized_args() solution = list(ret[0]) - sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh) - shape_consistency_pass(gm) + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh) + gm = runtime_apply_pass(gm) gm.recompile() nodes = [node for node in gm.graph.nodes] # TODO: wrap the gm to avoid the influence of the user training code diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index 23d866bbe..f4a5ae7ac 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -3,8 +3,13 @@ from torch.fx import GraphModule from torchvision.models import resnet50 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP -from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions, - StrategiesConstructor) +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.shape_consistency import ShapeConsistencyManager @@ -53,7 +58,7 @@ def test_cost_graph(): gm.recompile() graph_analyser = GraphAnalyser(gm) liveness_list = graph_analyser.liveness_analysis() - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py new file mode 100644 index 000000000..99abacd13 --- /dev/null +++ b/tests/test_device/test_alpha_beta.py @@ -0,0 +1,33 @@ +from functools import partial + +import pytest +import torch.multiprocessing as mp + +from colossalai.device import AlphaBetaProfiler +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port + + +def check_alpha_beta(rank, physical_devices, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + profiler = AlphaBetaProfiler(physical_devices) + ab_dict = profiler.profile_ab() + for _, (alpha, beta) in ab_dict.items(): + assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10 + + +@pytest.mark.skip(reason="Skip because assertion fails for CI devices") +@pytest.mark.dist +@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@rerun_if_address_is_in_use() +def test_profile_alpha_beta(physical_devices): + world_size = 4 + run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_profile_alpha_beta() diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py new file mode 100644 index 000000000..e32bebdd9 --- /dev/null +++ b/tests/test_device/test_extract_alpha_beta.py @@ -0,0 +1,39 @@ +from functools import partial + +import pytest +import torch.multiprocessing as mp + +from colossalai.device import AlphaBetaProfiler +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port + + +def check_extract_alpha_beta(rank, physical_devices, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + profiler = AlphaBetaProfiler(physical_devices) + + mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh() + for alpha in mesh_alpha: + assert alpha > 0 and alpha < 1e-3 + for beta in mesh_beta: + assert beta > 0 and beta < 1e-10 + + +@pytest.mark.skip(reason="Skip because assertion may fail for CI devices") +@pytest.mark.dist +@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@rerun_if_address_is_in_use() +def test_profile_alpha_beta(physical_devices): + world_size = 4 + run_func = partial(check_extract_alpha_beta, + physical_devices=physical_devices, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_profile_alpha_beta() diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py new file mode 100644 index 000000000..591eafb2a --- /dev/null +++ b/tests/test_device/test_search_logical_device_mesh.py @@ -0,0 +1,36 @@ +from functools import partial + +import pytest +import torch.multiprocessing as mp + +from colossalai.device import AlphaBetaProfiler +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port + + +def check_alpha_beta(rank, physical_devices, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + profiler = AlphaBetaProfiler(physical_devices) + best_logical_mesh = profiler.search_best_logical_mesh() + + if physical_devices == [0, 1, 2, 3]: + assert best_logical_mesh == [[0, 1], [2, 3]] + elif physical_devices == [0, 3]: + assert best_logical_mesh == [[0, 3]] + + +@pytest.mark.skip(reason="Skip because assertion may fail for CI devices") +@pytest.mark.dist +@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@rerun_if_address_is_in_use() +def test_profile_alpha_beta(physical_devices): + world_size = 4 + run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_profile_alpha_beta() diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 3914d57be..9949d49c1 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -2,11 +2,13 @@ import copy import re from typing import Callable -import colossalai import pytest import torch import torch.multiprocessing as mp import torchvision.models as tm +from torch.fx import GraphModule + +import colossalai from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta @@ -14,7 +16,6 @@ from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.utils import free_port -from torch.fx import GraphModule if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor @@ -94,6 +95,7 @@ def _run_ckpt_solver(rank): gpc.destroy() +@pytest.mark.skip("TODO(super-dainiu): refactor all tests.") @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def test_ckpt_solver(): mp.spawn(_run_ckpt_solver, nprocs=1) diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index 08044c687..83df1bb5e 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -1,14 +1,15 @@ -import torch -import torch.nn.functional as F import pytest +import torch import torch.multiprocessing as mp -from torch.utils.checkpoint import checkpoint +import torch.nn.functional as F from torch.fx import GraphModule -from colossalai.fx import ColoTracer +from torch.utils.checkpoint import checkpoint + import colossalai -from colossalai.utils import free_port from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule +from colossalai.utils import free_port try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -92,11 +93,11 @@ def _run_act_ckpt_codegen(rank): offload_starts = ['mlp1_linear1'] for node in graph.nodes: if node.name in ckpt_nodes: - assert hasattr(node, 'activation_checkpoint') + assert 'activation_checkpoint' in node.meta # annotate the selected node for offload if node.name in offload_starts: - setattr(node, 'activation_offload', True) + node.meta['activation_offload'] = True gm = ColoGraphModule(model, graph) gm.recompile() @@ -148,11 +149,11 @@ def _run_act_ckpt_python_code_torch11(rank): offload_starts = ['mlp1_linear1'] for node in graph.nodes: if node.name in ckpt_nodes: - assert hasattr(node, 'activation_checkpoint') + assert 'activation_checkpoint' in node.meta # annotate the selected node for offload if node.name in offload_starts: - setattr(node, 'activation_offload', True) + node.meta['activation_offload'] = True gm = ColoGraphModule(model, graph) gm.recompile() diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py index 56f25175e..6b3a49d18 100644 --- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -1,14 +1,15 @@ -import torch -import torch.nn.functional as F import pytest +import torch import torch.multiprocessing as mp -from torch.utils.checkpoint import checkpoint +import torch.nn.functional as F from torch.fx import GraphModule -from colossalai.fx import ColoTracer +from torch.utils.checkpoint import checkpoint + import colossalai -from colossalai.utils import free_port from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule +from colossalai.utils import free_port try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -57,16 +58,16 @@ def _run_act_ckpt_codegen(rank): # annotate nested checkpoint for node in graph.nodes: if node.name == "linear1": - setattr(node, "activation_checkpoint", [0, 0, 0]) + node.meta['activation_checkpoint'] = [0, 0, 0] continue if node.name == "linear2": - setattr(node, "activation_checkpoint", [0, 0, None]) + node.meta['activation_checkpoint'] = [0, 0, None] if node.name == "linear3": - setattr(node, "activation_checkpoint", [0, 0, 1]) + node.meta['activation_checkpoint'] = [0, 0, 1] if node.name == "linear4": - setattr(node, "activation_checkpoint", [0, 1, None]) + node.meta['activation_checkpoint'] = [0, 1, None] if node.name == "linear5": - setattr(node, "activation_checkpoint", 1) + node.meta['activation_checkpoint'] = 1 gm = ColoGraphModule(model, graph) gm.recompile() @@ -114,16 +115,16 @@ def _run_act_ckpt_python_code_torch11(rank): # annotate nested checkpoint for node in graph.nodes: if node.name == "linear1": - setattr(node, "activation_checkpoint", [0, 0, 0]) + node.meta['activation_checkpoint'] = [0, 0, 0] continue if node.name == "linear2": - setattr(node, "activation_checkpoint", [0, 0, None]) + node.meta['activation_checkpoint'] = [0, 0, None] if node.name == "linear3": - setattr(node, "activation_checkpoint", [0, 0, 1]) + node.meta['activation_checkpoint'] = [0, 0, 1] if node.name == "linear4": - setattr(node, "activation_checkpoint", [0, 1, None]) + node.meta['activation_checkpoint'] = [0, 1, None] if node.name == "linear5": - setattr(node, "activation_checkpoint", 1) + node.meta['activation_checkpoint'] = 1 gm = ColoGraphModule(model, graph) gm.recompile() diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py index edaeb50cb..5d090066c 100644 --- a/tests/test_fx/test_codegen/test_offload_codegen.py +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -1,14 +1,16 @@ import copy -import torch -import torch.nn.functional as F + import pytest +import torch import torch.multiprocessing as mp +import torch.nn.functional as F from torch.fx import GraphModule -from colossalai.fx import ColoTracer + import colossalai -from colossalai.utils import free_port from colossalai.core import global_context as gpc +from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule +from colossalai.utils import free_port try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -83,16 +85,16 @@ def _run_offload_codegen(rank): # of input offload for node in graph.nodes: if node.name == "linear0": - setattr(node, "activation_offload", [0, True, False]) + node.meta['activation_offload'] = [0, True, False] if node.name == "linear1": - setattr(node, "activation_offload", [0, True, False]) + node.meta['activation_offload'] = [0, True, False] if node.name == "linear2": - setattr(node, "activation_offload", [1, True, True]) + node.meta['activation_offload'] = [1, True, True] if node.name == "linear4": - setattr(node, "activation_offload", [2, False, True]) + node.meta['activation_offload'] = [2, False, True] if node.name == "linear5": - setattr(node, "activation_checkpoint", [0]) - setattr(node, "activation_offload", True) + node.meta['activation_checkpoint'] = [0] + node.meta['activation_offload'] = True gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() @@ -138,16 +140,16 @@ def _run_offload_codegen_torch11(rank): # of input offload for node in graph.nodes: if node.name == "linear0": - setattr(node, "activation_offload", [0, True, False]) + node.meta['activation_offload'] = [0, True, False] if node.name == "linear1": - setattr(node, "activation_offload", [0, True, False]) + node.meta['activation_offload'] = [0, True, False] if node.name == "linear2": - setattr(node, "activation_offload", [1, True, True]) + node.meta['activation_offload'] = [1, True, True] if node.name == "linear4": - setattr(node, "activation_offload", [2, False, True]) + node.meta['activation_offload'] = [2, False, True] if node.name == "linear5": - setattr(node, "activation_checkpoint", [0]) - setattr(node, "activation_offload", True) + node.meta['activation_checkpoint'] = [0] + node.meta['activation_offload'] = True gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() diff --git a/tests/test_fx/test_complete_workflow.py b/tests/test_fx/test_complete_workflow.py index b17f2cdb6..a21a351f8 100644 --- a/tests/test_fx/test_complete_workflow.py +++ b/tests/test_fx/test_complete_workflow.py @@ -1,16 +1,18 @@ -import colossalai -import torch -import torch.nn as nn -import pytest -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn + +import colossalai from colossalai.fx import ColoTracer -from colossalai.utils.model.lazy_init_context import LazyInitContext from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass -from colossalai.utils import free_port from colossalai.tensor import ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.model.lazy_init_context import LazyInitContext class MLP(torch.nn.Module): @@ -30,22 +32,27 @@ class MLP(torch.nn.Module): return x -def run_workflow(world_size): +def run_workflow(world_size, dev): # initailization with LazyInitContext() as ctx: model = MLP(16) + for param in model.parameters(): + assert param.is_meta + # tracing tracer = ColoTracer() graph = tracer.trace(model) gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) # annotate - annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup()) + annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup(tp_degree=world_size)) annotated_gm.recompile() # materialization and sharding - ctx.lazy_init_parameters(annotated_gm) + ctx.lazy_init_parameters(annotated_gm, device=dev) + for param in model.parameters(): + assert not param.is_meta # # check sharding assert list(model.linear1.weight.shape) == [16 // world_size, 16] @@ -54,24 +61,27 @@ def run_workflow(world_size): # test forward to make sure that IR transform will produce the same results # like how ColoTensor would do it normally - data = torch.rand(4, 16) + data = torch.rand(4, 16, device=dev) non_fx_out = model(data) fx_out = annotated_gm(data) - assert torch.equal(non_fx_out, fx_out) + assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}' -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, dev, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_workflow(world_size) + run_workflow(world_size, dev) @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize('dev', ['cuda', 'cpu']) @rerun_if_address_is_in_use() -def test_complete_workflow(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) +def test_complete_workflow(world_size, dev): + if dev == 'cpu' and world_size > 1: + return + run_func = partial(run_dist, world_size=world_size, dev=dev, port=free_port()) mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__': - test_complete_workflow(2) + test_complete_workflow(1, 'cuda') diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py index 08d20c894..6ef861bde 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py @@ -1,12 +1,13 @@ -import transformers -import torch import pytest +import torch +import transformers from hf_utils import split_model_and_compare_output BATCH_SIZE = 2 SEQ_LENGHT = 16 +@pytest.mark.skip('balance split v2 is not ready') def test_single_sentence_albert(): MODEL_LIST = [ transformers.AlbertModel, diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py index a3699b660..a7550413f 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py @@ -1,12 +1,13 @@ -import transformers -import torch import pytest +import torch +import transformers from hf_utils import split_model_and_compare_output BATCH_SIZE = 2 SEQ_LENGHT = 16 +@pytest.mark.skip('balance split v2 is not ready') def test_single_sentence_bert(): MODEL_LIST = [ transformers.BertModel, diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py index b973ac854..6181c5c07 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py @@ -1,6 +1,6 @@ -import transformers -import torch import pytest +import torch +import transformers from hf_utils import split_model_and_compare_output BATCH_SIZE = 64 @@ -9,6 +9,7 @@ NUM_EPOCHS = 2 NUM_CHUNKS = 1 +@pytest.mark.skip('balance split v2 is not ready') def test_gpt(): MODEL_LIST = [ transformers.GPT2Model, diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py index a55ea54fe..1a9b36be8 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py @@ -1,12 +1,13 @@ import pytest -import transformers import torch +import transformers from hf_utils import split_model_and_compare_output BATCH_SIZE = 1 SEQ_LENGHT = 16 +@pytest.mark.skip('balance split v2 is not ready') def test_opt(): MODEL_LIST = [ transformers.OPTModel, diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py index d20d18842..16d016374 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -1,12 +1,13 @@ import pytest -import transformers import torch +import transformers from hf_utils import split_model_and_compare_output BATCH_SIZE = 1 SEQ_LENGHT = 16 +@pytest.mark.skip('balance split v2 is not ready') def test_t5(): MODEL_LIST = [ transformers.T5Model, diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py index 7c3764f34..6fb1f6f4b 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py +++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py @@ -1,9 +1,10 @@ -import torch -import timm.models as tm -from timm_utils import split_model_and_compare_output import pytest +import timm.models as tm +import torch +from timm_utils import split_model_and_compare_output +@pytest.mark.skip('balance split v2 is not ready') def test_timm_models_without_control_flow(): MODEL_LIST = [ @@ -24,6 +25,7 @@ def test_timm_models_without_control_flow(): split_model_and_compare_output(model, data) +@pytest.mark.skip('balance split v2 is not ready') def test_timm_models_with_control_flow(): torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_pipeline/test_topo/test_topo.py b/tests/test_fx/test_pipeline/test_topo/test_topo.py new file mode 100644 index 000000000..75c748705 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_topo/test_topo.py @@ -0,0 +1,43 @@ +import pytest +import torch +import transformers +from topo_utils import split_model_and_get_DAG, check_topo, MLP + +BATCH_SIZE = 1 +SEQ_LENGHT = 16 + +def test_opt(): + MODEL_LIST = [ + MLP, + transformers.OPTModel, + ] + + CONFIGS = [ + {'dim': 10, 'layers': 12}, + transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), + ] + + def data_gen_MLP(): + x = torch.zeros((16, 10)) + kwargs = dict(x=x) + return kwargs + + def data_gen_OPT(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + DATAGEN = [ + data_gen_MLP, + data_gen_OPT, + ] + + for i, model_cls in enumerate(MODEL_LIST): + model = model_cls(config=CONFIGS[i]) + top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i]) + # print(f'{top_mod=}\n----\n{topo=}') + check_topo(top_mod, topo) + +if __name__ == '__main__': + test_opt() \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py new file mode 100644 index 000000000..55dd65201 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_topo/topo_utils.py @@ -0,0 +1,92 @@ +import torch +from torch.fx import GraphModule +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.pipeline.middleware.adaptor import get_fx_topology +import random +import numpy as np + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) + +class MLP(torch.nn.Module): + def __init__(self, config={}): + super().__init__() + dim = config['dim'] + layers = config['layers'] + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(torch.nn.Linear(dim, dim, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + +def split_model_and_get_DAG(model, data_gen): + model.eval() + + # generate input sample + kwargs = data_gen() + + # tracing model + tracer = ColoTracer() + try: + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + top_module, split_submodules = split_with_split_nodes_pass(annotated_model) + + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + + return top_module, split_submodules[0]._topo + +def check_input(top_module, input_partition: Partition): + partition_output = input_partition.get_output_vals() + arg_pos = 0 + for node in top_module.graph.nodes: + if node.op == 'placeholder': + cur_checkee = partition_output[arg_pos] + to_partition_and_offset = cur_checkee.get() + assert len(to_partition_and_offset) == len(node.users.keys()) + arg_pos += 1 + + assert arg_pos == len(partition_output) + +def check_submod(top_module, part_id, mid_partition: Partition): + partition_input = mid_partition.get_input_vals() + partition_output = mid_partition.get_output_vals() + + cnt = 1 + cur_node = None + for node in top_module.graph.nodes: + if node.name.startswith('submod'): + cnt += 1 + if cnt == part_id: + cur_node = node + break + + assert len(partition_input) == len(cur_node.args) + assert len(partition_output) == len(cur_node.users) + +def check_topo(top_module, topo: Topo): + input_partition = topo.get_input_partition() + mid_partitions = topo.get_mid_partitions() + + check_input(top_module, input_partition) + for part_id, submod in mid_partitions.items(): + check_submod(top_module, part_id, submod) + \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py index b308d99c2..5d47be2c7 100644 --- a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py +++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py @@ -1,13 +1,16 @@ +import inspect +import random + +import numpy as np +import pytest import torch import torchvision import torchvision.models as tm -from colossalai.fx import ColoTracer -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from torch.fx import GraphModule from packaging import version -import random -import numpy as np -import inspect +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass MANUAL_SEED = 0 random.seed(MANUAL_SEED) @@ -16,6 +19,7 @@ torch.manual_seed(MANUAL_SEED) torch.backends.cudnn.deterministic = True +@pytest.mark.skip('balance split v2 is not ready') def test_torchvision_models(): MODEL_LIST = [ tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py index a9921af3c..c71796018 100644 --- a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py +++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py @@ -3,13 +3,14 @@ from typing import Optional, Tuple, Union import torch import torch.fx import torchvision.models as tm -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size) -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.testing.pytest_wrapper import run_on_environment_flag from gpt_utils import gpt2_medium, gpt2_xl from torch.fx import symbolic_trace +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag + if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor diff --git a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py index 3fd39b393..a834951bb 100644 --- a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py +++ b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py @@ -1,9 +1,10 @@ import torch import torch.nn as nn -from colossalai.fx import ColoTracer from torch.fx import GraphModule from torch.utils.checkpoint import checkpoint +from colossalai.fx import ColoTracer + class MLP(torch.nn.Module): @@ -44,11 +45,11 @@ def test_activation_checkpoint_annotation(): for node in gm.graph.nodes: if node.name in ['mlp_1_linear1', 'mlp_1_linear2']: - assert getattr(node, 'activation_checkpoint', -1) == 0 + assert node.meta.get('activation_checkpoint', -1) == 0 for node in gm.graph.nodes: if node.name in ['mlp_2_linear1', 'mlp_2_linear2']: - assert getattr(node, 'activation_checkpoint', -1) == 1 + assert node.meta.get('activation_checkpoint', -1) == 1 tracer = ColoTracer(trace_act_ckpt=False) graph = tracer.trace(module) diff --git a/tests/test_fx/test_tracer/test_bias_addition_module.py b/tests/test_fx/test_tracer/test_bias_addition_module.py new file mode 100644 index 000000000..afa30a217 --- /dev/null +++ b/tests/test_fx/test_tracer/test_bias_addition_module.py @@ -0,0 +1,114 @@ +import torch + +from colossalai.fx import ColoGraphModule, ColoTracer + + +class LinearModel(torch.nn.Module): + + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + x = self.linear(x) + x = x * 2 + + return x + + +class ConvModel(torch.nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias) + + def forward(self, x): + x = self.conv(x) + x = x * 2 + + return x + + +def test_linear_module(): + model = LinearModel(3, 6) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %linear_weight : [#users=1] = get_attr[target=linear.weight] + # %linear_bias : [#users=1] = get_attr[target=linear.bias] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')}) + # def forward(self, x : torch.Tensor): + # linear_weight = self.linear.weight + # linear_bias = self.linear.bias + # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None + # add = linear + linear_bias; linear = linear_bias = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + gm.recompile() + node_list = list(graph.nodes) + for node in node_list: + if node.op == 'output': + continue + assert hasattr(node, '_meta_data') + weight_node = node_list[1] + bias_node = node_list[2] + linear_node = node_list[3] + add_node = node_list[4] + assert weight_node._meta_data.shape == (6, 3) + assert bias_node._meta_data.shape == (6,) + assert linear_node._meta_data.shape == (3, 6) + assert add_node._meta_data.shape == (3, 6) + + +def test_conv_module(): + model = ConvModel(3, 6, 2) + tracer = ColoTracer() + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv_weight : [#users=1] = get_attr[target=conv.weight] + # %conv_bias : [#users=1] = get_attr[target=conv.bias] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) + # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) + # return mul + graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + # def forward(self, x : torch.Tensor): + # conv_weight = self.conv.weight + # conv_bias = self.conv.bias + # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None + # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None + # add = conv2d + view; conv2d = view = None + # mul = add * 2; add = None + # return mul + gm = ColoGraphModule(model, graph) + + gm.recompile() + node_list = list(graph.nodes) + for node in node_list: + if node.op == 'output': + continue + assert hasattr(node, '_meta_data') + weight_node = node_list[1] + bias_node = node_list[2] + conv_node = node_list[3] + view_node = node_list[4] + add_node = node_list[5] + assert weight_node._meta_data.shape == (6, 3, 2, 2) + assert bias_node._meta_data.shape == (6,) + assert conv_node._meta_data.shape == (4, 6, 63, 63) + assert view_node._meta_data.shape == (6, 1, 1) + assert add_node._meta_data.shape == (4, 6, 63, 63) + + +if __name__ == '__main__': + test_linear_module() + test_conv_module() diff --git a/tests/test_fx/test_tracer/test_hf_model/utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py similarity index 77% rename from tests/test_fx/test_tracer/test_hf_model/utils.py rename to tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index fb0702455..6d93fe040 100644 --- a/tests/test_fx/test_tracer/test_hf_model/utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -3,24 +3,19 @@ from numpy import isin from torch.fx import GraphModule from torch.utils._pytree import tree_flatten -from colossalai.fx import ColoTracer +from colossalai.fx import symbolic_trace def trace_model_and_compare_output(model, data_gen): # must turn on eval mode to ensure the output is consistent model.eval() - # make sure that the model is traceable - tracer = ColoTracer() - try: kwargs = data_gen() meta_args = {k: v.to('meta') for k, v in kwargs.items()} - graph = tracer.trace(root=model, meta_args=meta_args) + gm = symbolic_trace(model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() # run forward inputs = data_gen() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index 5837340fa..9c36b0c9c 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -1,7 +1,7 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 2 SEQ_LENGTH = 16 diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index 1a66b1151..62273e2d5 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -1,7 +1,7 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 2 SEQ_LENGTH = 16 diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index ab6e08694..04e874bec 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -1,11 +1,15 @@ -import diffusers import pytest import torch import transformers -from torch.fx import GraphModule -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output -from colossalai.fx import ColoTracer +from colossalai.fx import symbolic_trace + +try: + import diffusers + HAS_DIFFUSERS = True +except ImportError: + HAS_DIFFUSERS = False BATCH_SIZE = 2 SEQ_LENGTH = 5 @@ -16,6 +20,7 @@ LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8) TIME_STEP = 2 +@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed") def test_vae(): MODEL_LIST = [ diffusers.AutoencoderKL, @@ -26,11 +31,7 @@ def test_vae(): model = model_cls() sample = torch.zeros(LATENTS_SHAPE) - tracer = ColoTracer() - graph = tracer.trace(root=model) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval() @@ -80,6 +81,7 @@ def test_clip(): trace_model_and_compare_output(model, data_gen) +@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed") @pytest.mark.skip(reason='cannot pass the test yet') def test_unet(): MODEL_LIST = [ @@ -91,11 +93,7 @@ def test_unet(): model = model_cls() sample = torch.zeros(LATENTS_SHAPE) - tracer = ColoTracer() - graph = tracer.trace(root=model) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index ae2e752f9..ad4c9684d 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -1,12 +1,14 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 1 SEQ_LENGTH = 16 +# TODO: remove this skip once we handle the latest gpt model +@pytest.mark.skip def test_gpt(): MODEL_LIST = [ transformers.GPT2Model, diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index c39e97a16..06260176e 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -1,7 +1,7 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 1 SEQ_LENGTH = 16 diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index b6749c828..71e782fdd 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -1,7 +1,7 @@ import pytest import torch import transformers -from utils import trace_model_and_compare_output +from hf_tracer_utils import trace_model_and_compare_output BATCH_SIZE = 1 SEQ_LENGTH = 16 diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 1ce679d4c..28ec3d825 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -1,11 +1,11 @@ -import torch -import timm.models as tm -from colossalai.fx import ColoTracer -from torch.fx import GraphModule import pytest +import timm.models as tm +import torch + +from colossalai.fx import symbolic_trace -def trace_and_compare(model_cls, tracer, data, meta_args=None): +def trace_and_compare(model_cls, data, meta_args=None): # trace model = model_cls() @@ -14,15 +14,13 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None): # without this statement, the torch.nn.functional.batch_norm will always be in training mode model.eval() - graph = tracer.trace(root=model, meta_args=meta_args) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model, meta_args=meta_args) # run forward with torch.no_grad(): fx_out = gm(data) non_fx_out = model(data) - + # compare output if isinstance(fx_out, tuple): # some models produce tuple as output @@ -30,7 +28,8 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None): assert torch.allclose(v1, v2), f'{model.__class__.__name__} has inconsistent outputs, {v1} vs {v2}' else: assert torch.allclose( - fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out, non_fx_out, + atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' def test_timm_models_without_control_flow(): @@ -47,11 +46,10 @@ def test_timm_models_without_control_flow(): tm.deit_base_distilled_patch16_224, ] - tracer = ColoTracer() data = torch.rand(2, 3, 224, 224) for model_cls in MODEL_LIST: - trace_and_compare(model_cls, tracer, data) + trace_and_compare(model_cls, data) def test_timm_models_with_control_flow(): @@ -62,13 +60,12 @@ def test_timm_models_with_control_flow(): tm.swin_transformer.swin_base_patch4_window7_224 ] - tracer = ColoTracer() data = torch.rand(2, 3, 224, 224) meta_args = {'x': data.to('meta')} for model_cls in MODEL_LIST_WITH_CONTROL_FLOW: - trace_and_compare(model_cls, tracer, data, meta_args) + trace_and_compare(model_cls, data, meta_args) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py index 894810fe6..702c5f8f6 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py @@ -1,19 +1,16 @@ -from colossalai.fx import ColoTracer import torch -from torch.fx import GraphModule, Tracer + +from colossalai.fx import symbolic_trace def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False): data = data_gen() concrete_args = data if need_concrete else {} meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {} - tracer = ColoTracer() model.eval() - graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args) with torch.no_grad(): non_fx_out = model(**data) @@ -24,8 +21,9 @@ def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwa fx_out = gm(**data) if isinstance(fx_out, tuple): for non_fx, fx in zip(non_fx_out, fx_out): - assert torch.allclose(non_fx, - fx), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + assert torch.allclose( + non_fx, fx, atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' else: assert torch.allclose( - fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out, non_fx_out, + atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index 0f1f294e4..dbe8a62e7 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -1,19 +1,22 @@ -from colossalai.fx.tracer import meta_patch -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.fx.tracer.meta_patch.patched_function import python_ops +import pytest import torch -from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.models import deepfm, dlrm -import colossalai.fx as fx -import pdb -from torch.fx import GraphModule + +from colossalai.fx import symbolic_trace + +try: + from torchrec.models import deepfm + from torchrec.modules.embedding_configs import EmbeddingBagConfig + from torchrec.modules.embedding_modules import EmbeddingBagCollection + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + NOT_TORCHREC = False +except ImportError: + NOT_TORCHREC = True BATCH = 2 SHAPE = 10 +@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed') def test_torchrec_deepfm_models(): MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch] @@ -36,9 +39,6 @@ def test_torchrec_deepfm_models(): # Dense Features features = torch.rand((BATCH, SHAPE)) - # Tracer - tracer = ColoTracer() - for model_cls in MODEL_LIST: # Initializing model if model_cls == deepfm.DenseArch: @@ -53,9 +53,7 @@ def test_torchrec_deepfm_models(): model = model_cls(ebc) # Setup GraphModule - graph = tracer.trace(model) - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval() diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 5999a1abf..2f9fd8fe5 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -1,19 +1,23 @@ -from colossalai.fx.tracer import meta_patch -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.fx.tracer.meta_patch.patched_function import python_ops import torch -from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.models import deepfm, dlrm -import colossalai.fx as fx -import pdb -from torch.fx import GraphModule + +from colossalai.fx import symbolic_trace + +try: + from torchrec.models import dlrm + from torchrec.modules.embedding_configs import EmbeddingBagConfig + from torchrec.modules.embedding_modules import EmbeddingBagCollection + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + NOT_TORCHREC = False +except ImportError: + NOT_TORCHREC = True + +import pytest BATCH = 2 SHAPE = 10 +@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed') def test_torchrec_dlrm_models(): MODEL_LIST = [ dlrm.DLRM, @@ -46,8 +50,6 @@ def test_torchrec_dlrm_models(): # Sparse Features sparse_features = torch.rand((BATCH, len(keys), SHAPE)) - # Tracer - tracer = ColoTracer() for model_cls in MODEL_LIST: # Initializing model @@ -72,12 +74,9 @@ def test_torchrec_dlrm_models(): # Setup GraphModule if model_cls == dlrm.InteractionV2Arch: concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features} - graph = tracer.trace(model, concrete_args=concrete_args) + gm = symbolic_trace(model, concrete_args=concrete_args) else: - graph = tracer.trace(model) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval() diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index 046a0dabe..2a6c6ae16 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -2,8 +2,8 @@ import torch import torchvision import torchvision.models as tm from packaging import version -from colossalai.fx import ColoTracer -from torch.fx import GraphModule + +from colossalai.fx import symbolic_trace def test_torchvision_models(): @@ -20,7 +20,6 @@ def test_torchvision_models(): torch.backends.cudnn.deterministic = True - tracer = ColoTracer() data = torch.rand(2, 3, 224, 224) for model_cls in MODEL_LIST: @@ -30,10 +29,7 @@ def test_torchvision_models(): else: model = model_cls() - graph = tracer.trace(root=model) - - gm = GraphModule(model, graph, model.__class__.__name__) - gm.recompile() + gm = symbolic_trace(model) model.eval() gm.eval() diff --git a/tests/test_gemini/test_param_op.py b/tests/test_gemini/test_param_op.py index ed9d51d9a..daf386d6d 100644 --- a/tests/test_gemini/test_param_op.py +++ b/tests/test_gemini/test_param_op.py @@ -1,38 +1,9 @@ -from colossalai.gemini.paramhooks import BaseParamHookMgr -from torch import nn -import torch -import torch.nn.functional as F import copy +import torch -class SubNet(nn.Module): - - def __init__(self, out_features) -> None: - super().__init__() - self.bias = nn.Parameter(torch.zeros(out_features)) - - def forward(self, x, weight): - return F.linear(x, weight, self.bias) - - -class Net(nn.Module): - - def __init__(self, checkpoint=False) -> None: - super().__init__() - self.fc1 = nn.Linear(5, 5) - self.sub_fc = SubNet(5) - self.fc2 = nn.Linear(5, 1) - - def forward(self, x): - x = self.fc1(x) - x = self.sub_fc(x, self.fc1.weight) - x = self.fc1(x) - x = self.fc2(x) - return x - - -def net_data(): - return (torch.randn(2, 5, dtype=torch.float, device='cuda'),) +from colossalai.gemini.paramhooks import BaseParamHookMgr +from tests.components_to_test.registry import non_distributed_component_funcs def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: @@ -41,54 +12,68 @@ def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> boo return torch.allclose(tensor_a, tensor_b) +def run_model(model, inputs, label, criterion, use_param_hook=False): + if use_param_hook: + + class HooKWrapper: + + def __init__(self) -> None: + self.hook_triggered_times = 0 + + def wrapper_func(self): + + def hook(param, grad) -> torch.Tensor or None: + self.hook_triggered_times += 1 + return grad + + return hook + + hookwrapper = HooKWrapper() + param_list = [p for p in model.parameters()] + hook_mgr = BaseParamHookMgr(param_list) + hook_mgr.register_backward_hooks(hookwrapper.wrapper_func()) + + model.zero_grad(set_to_none=True) + + with torch.cuda.amp.autocast(): + if criterion: + y = model(inputs) + loss = criterion(y, label) + else: + loss = model(inputs, label) + loss = loss.float() + loss.backward() + + if use_param_hook: + hook_mgr.remove_hooks() + return hookwrapper.hook_triggered_times + + def test_base_param_hook(): - torch.manual_seed(0) - model = Net(checkpoint=True).cuda() - model.train() - inputs = net_data() + test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model'] + # test_models = ['bert'] - def run_model(model, inputs, use_param_hook=False): - if use_param_hook: + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() - class HooKWrapper: + torch.manual_seed(0) + model = model_builder(checkpoint=True).cuda() + model.train() - def __init__(self) -> None: - self.hook_triggered_times = 0 + for i, (inputs, label) in enumerate(train_dataloader): + if i > 0: + break + model_copy = copy.deepcopy(model) - def wrapper_func(self): + run_model(model, inputs.cuda(), label.cuda(), criterion, False) + ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True) - def hook(param, grad) -> torch.Tensor or None: - self.hook_triggered_times += 1 - return grad + # Make sure param hook has only be fired once in case of parameter sharing + assert ret2 == len(list(model.parameters())) - return hook - - hookwrapper = HooKWrapper() - param_list = [p for p in model.parameters()] - hook_mgr = BaseParamHookMgr(param_list) - hook_mgr.register_backward_hooks(hookwrapper.wrapper_func()) - - model.zero_grad(set_to_none=True) - - with torch.cuda.amp.autocast(): - y = model(*inputs) - loss = y.sum() - loss.backward() - - if use_param_hook: - hook_mgr.remove_hooks() - return hookwrapper.hook_triggered_times - - model_copy = copy.deepcopy(model) - - run_model(model, inputs, False) - ret2 = run_model(model_copy, inputs, True) - - # Make sure param hook has only be fired once in case of parameter sharing - assert ret2 == len(list(model.parameters())) - - for p, p_copy in zip(model.parameters(), model_copy.parameters()): - assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}" + for p, p_copy in zip(model.parameters(), model_copy.parameters()): + assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}" if __name__ == '__main__': diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py new file mode 100644 index 000000000..294868458 --- /dev/null +++ b/tests/test_gemini/test_runtime_mem_tracer.py @@ -0,0 +1,52 @@ +from copy import deepcopy + +import numpy as np +import torch + +from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs + + +def test_runtime_mem_tracer(): + test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + with ColoInitContext(device='cpu'): + model = model_builder(checkpoint=False) + + model_bk = deepcopy(model) + runtime_mem_tracer = RuntimeMemTracer(model) + + for i, (data, label) in enumerate(train_dataloader): + if i > 1: + break + data = data.cuda() + label = label.cuda() + + run_fwd_bwd(runtime_mem_tracer, data, label, criterion, optimizer=runtime_mem_tracer) + + for p1, p2 in zip(model_bk.parameters(), model.parameters()): + torch.allclose(p1.to(torch.half), p2) + + non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list('cuda') + cuda_non_model_data_list = np.array(non_model_data_list) / 1024**2 + print("cuda_non_model_data_list", len(cuda_non_model_data_list)) + print(non_model_data_list) + + cnt1 = 0 + for p in runtime_mem_tracer.parameters_in_runtime_order(): + cnt1 += 1 + cnt2 = 0 + for p in model.parameters(): + cnt2 += 1 + assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}' + del model + + +if __name__ == '__main__': + test_runtime_mem_tracer() diff --git a/tests/test_gemini/test_stateful_tensor_mgr.py b/tests/test_gemini/test_stateful_tensor_mgr.py deleted file mode 100644 index 39c07f279..000000000 --- a/tests/test_gemini/test_stateful_tensor_mgr.py +++ /dev/null @@ -1,121 +0,0 @@ -import torch -import colossalai -import pytest -import torch.multiprocessing as mp -from colossalai.utils.cuda import get_current_device -from colossalai.gemini.memory_tracer import MemStatsCollector -from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER -from colossalai.utils.memory import colo_set_process_memory_fraction -from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 -from colossalai.gemini.stateful_tensor import TensorState -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from torch.nn.parameter import Parameter -from typing import List -from functools import partial - -from colossalai.gemini import StatefulTensorMgr -from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy - - -class Net(torch.nn.Module): - - def __init__(self) -> None: - super().__init__() - # each parameter is 128 MB - self.p0 = Parameter(torch.empty(1024, 1024, 32)) - self.p1 = Parameter(torch.empty(1024, 1024, 32)) - self.p2 = Parameter(torch.empty(1024, 1024, 32)) - - -def limit_cuda_memory(memory_in_g: float): - cuda_capacity = torch.cuda.get_device_properties(get_current_device()).total_memory - fraction = (memory_in_g * 1024**3) / cuda_capacity - colo_set_process_memory_fraction(fraction) - - -def run_stm(): - # warmup phase use 20% CUDA memory to store params - # only 2 params can be on CUDA - limit_cuda_memory(1.26) - model = Net() - for p in model.parameters(): - p.colo_attr = ShardedParamV2(p, set_data_none=True) - GLOBAL_MODEL_DATA_TRACER.register_model(model) - mem_collector = MemStatsCollector() - tensor_placement_policy = AutoTensorPlacementPolicy(mem_stats_collector=mem_collector) - stateful_tensor_mgr = StatefulTensorMgr(tensor_placement_policy) - stateful_tensors = [p.colo_attr.sharded_data_tensor for p in model.parameters()] - stateful_tensor_mgr.register_stateful_tensor_list(stateful_tensors) - - mem_collector.start_collection() - # Compute order: 0 1 2 0 1 - # warmup - # use naive eviction strategy - apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr) - mem_collector.sample_model_data() - mem_collector.sample_overall_data() - apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr) - mem_collector.sample_model_data() - mem_collector.sample_overall_data() - apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr) - mem_collector.sample_model_data() - mem_collector.sample_overall_data() - apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr) - mem_collector.sample_model_data() - mem_collector.sample_overall_data() - apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr) - mem_collector.sample_model_data() - mem_collector.finish_collection() - stateful_tensor_mgr.finish_iter() - - # warmup done - # only 2 params can be on CUDA - limit_cuda_memory(0.26 / tensor_placement_policy._steady_cuda_cap_ratio) - # use OPT-like eviction strategy - apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr) - apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr) - apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr) - apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr) - apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr) - - -def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter], - stateful_tensor_mgr: StatefulTensorMgr): - compute_param.colo_attr._sharded_data_tensor.trans_state(TensorState.COMPUTE) - for p in model.parameters(): - if p is not compute_param and p.colo_attr._sharded_data_tensor.state != TensorState.HOLD: - p.colo_attr._sharded_data_tensor.trans_state(TensorState.HOLD) - stateful_tensor_mgr.adjust_layout() - print_stats(model) - device = torch.device(torch.cuda.current_device()) - cuda_param_after_adjust = [hash(p) for p in cuda_param_after_adjust] - for n, p in model.named_parameters(): - if hash(p) in cuda_param_after_adjust: - assert p.colo_attr._sharded_data_tensor.device == device, f'{n} {p.colo_attr._sharded_data_tensor.device} vs {device}' - else: - assert p.colo_attr._sharded_data_tensor.device == torch.device('cpu') - - -def print_stats(model: torch.nn.Module): - msgs = [] - for n, p in model.named_parameters(): - msgs.append(f'{n}: {p.colo_attr._sharded_data_tensor.state}({p.colo_attr._sharded_data_tensor.device})') - print(f'[ {", ".join(msgs)} ]') - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_stm() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_stateful_tensor_manager(world_size=1): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - # this unit test can pass if available CUDA memory >= 1.5G - test_stateful_tensor_manager() diff --git a/tests/test_gemini/update/test_chunk_mgrv2.py b/tests/test_gemini/update/test_chunk_mgrv2.py index fa7a9b1b5..7d192fc63 100644 --- a/tests/test_gemini/update/test_chunk_mgrv2.py +++ b/tests/test_gemini/update/test_chunk_mgrv2.py @@ -1,70 +1,72 @@ -import torch -import colossalai -import pytest -import torch.multiprocessing as mp -from functools import partial -from colossalai.gemini.chunk import ChunkManager -from colossalai.testing import rerun_if_address_is_in_use, parameterize -from colossalai.utils import free_port -from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec -from tests.test_tensor.common_utils import debug_print - -CUDA_MEM_0 = {False: 512, True: 1024} -CUDA_MEM_1 = {False: 0, True: 1024} -CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} - - -@parameterize('keep_gathered', [True, False]) -@parameterize('pin_memory', [True, False]) -def exam_chunk_memory(keep_gathered, pin_memory): - pg = ProcessGroup() - - debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory)) - - params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)] - config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} - - chunk_manager = ChunkManager(config) - assert chunk_manager.total_mem['cpu'] == 0 - assert chunk_manager.total_mem['cuda'] == 0 - - for p in params: - chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory) - chunk_manager.close_all_groups() - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] - - chunks = chunk_manager.get_chunks(params) - - for chunk in chunks: - chunk_manager.access_chunk(chunk) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True] - - for chunk in chunks: - chunk_manager.release_chunk(chunk) - - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] - - for chunk in chunks: - chunk_manager.move_chunk(chunk, torch.device('cpu')) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered] - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_chunk_memory() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@rerun_if_address_is_in_use() -def test_chunk_manager(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_chunk_manager(2) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from tests.test_tensor.common_utils import debug_print + +CUDA_MEM_0 = {False: 512, True: 1024} +CUDA_MEM_1 = {False: 0, True: 1024} +CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} + + +@parameterize('keep_gathered', [True, False]) +@parameterize('pin_memory', [True, False]) +def exam_chunk_memory(keep_gathered, pin_memory): + pg = ProcessGroup() + + debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory)) + + params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)] + config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} + + chunk_manager = ChunkManager(config) + assert chunk_manager.total_mem['cpu'] == 0 + assert chunk_manager.total_mem['cuda'] == 0 + + for p in params: + chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory) + chunk_manager.close_all_groups() + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + + chunks = chunk_manager.get_chunks(params) + + for chunk in chunks: + chunk_manager.access_chunk(chunk) + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True] + + for chunk in chunks: + chunk_manager.release_chunk(chunk) + + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + + for chunk in chunks: + chunk_manager.move_chunk(chunk, torch.device('cpu')) + assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True] + assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered] + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_chunk_memory() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_chunk_manager(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_manager(2) diff --git a/tests/test_gemini/update/test_chunkv2.py b/tests/test_gemini/update/test_chunkv2.py index 57a49314f..96855410b 100644 --- a/tests/test_gemini/update/test_chunkv2.py +++ b/tests/test_gemini/update/test_chunkv2.py @@ -1,121 +1,125 @@ -import torch -import colossalai -import pytest -import torch.multiprocessing as mp -import torch.distributed as dist -from functools import partial -from colossalai.testing import rerun_if_address_is_in_use, parameterize -from colossalai.utils import free_port, get_current_device -from colossalai.tensor import ProcessGroup as ColoProcessGroup -from colossalai.tensor import ColoParameter -from colossalai.gemini import TensorState -from colossalai.gemini.chunk import Chunk - - -def dist_sum(x): - temp = torch.tensor([x], device=get_current_device()) - dist.all_reduce(temp) - return temp.item() - - -def add_param(param_list, param_cp_list, *args, **kwargs): - param = ColoParameter(torch.randn(*args, **kwargs)) - param_list.append(param) - param_cp_list.append(param.clone()) - - -def check_euqal(param, param_cp): - if param.device != param_cp.device: - temp = param.data.to(param_cp.device) - else: - temp = param.data - return torch.equal(temp, param_cp.data) - - -@parameterize('init_device', [None, torch.device('cpu')]) -@parameterize('keep_gathered', [True, False]) -@parameterize('pin_memory', [True, False]) -def exam_chunk_basic(init_device, keep_gathered, pin_memory): - world_size = torch.distributed.get_world_size() - pg = ColoProcessGroup() - my_chunk = Chunk(chunk_size=1024, - process_group=pg, - dtype=torch.float32, - init_device=init_device, - keep_gathered=keep_gathered, - pin_memory=pin_memory) - - param_list = [] - param_cp_list = [] - - add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') - add_param(param_list, param_cp_list, 4, 4) - add_param(param_list, param_cp_list, 4, 8, 2, device='cuda') - add_param(param_list, param_cp_list, 1, 1, 5) - - for param in param_list: - my_chunk.append_tensor(param) - assert my_chunk.utilized_size == 597 - for param, param_cp in zip(param_list, param_cp_list): - check_euqal(param, param_cp) - my_chunk.close_chunk() - - if keep_gathered is False: - assert my_chunk.cpu_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cpu' - assert my_chunk.can_move - my_chunk.shard_move(get_current_device()) - else: - assert my_chunk.chunk_total.size(0) == 1024 - assert my_chunk.device_type == 'cuda' - assert not my_chunk.can_move - - assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size - flag = my_chunk.has_inf_or_nan - assert not flag, "has_inf_or_nan is {}".format(flag) - - my_chunk.access_chunk() - assert my_chunk.device_type == 'cuda' - for param, param_cp in zip(param_list, param_cp_list): - check_euqal(param, param_cp) - - assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4 - my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE) - assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3 - assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1 - assert not my_chunk.can_release - - for param in param_list: - my_chunk.tensor_trans_state(param, TensorState.COMPUTE) - my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE) - - assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4 - assert my_chunk.can_reduce - my_chunk.reduce() - assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4 - - if keep_gathered is False: - assert my_chunk.cuda_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cuda' - assert my_chunk.can_move - else: - assert my_chunk.chunk_total.size(0) == 1024 - assert my_chunk.device_type == 'cuda' - assert not my_chunk.can_move - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_chunk_basic() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2, 4]) -@rerun_if_address_is_in_use() -def test_chunk_function(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_chunk_function(4) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini import TensorState +from colossalai.gemini.chunk import Chunk +from colossalai.tensor import ColoParameter +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device + + +def dist_sum(x): + temp = torch.tensor([x], device=get_current_device()) + dist.all_reduce(temp) + return temp.item() + + +def add_param(param_list, param_cp_list, *args, **kwargs): + param = ColoParameter(torch.randn(*args, **kwargs)) + param_list.append(param) + param_cp_list.append(param.clone()) + + +def check_euqal(param, param_cp): + if param.device != param_cp.device: + temp = param.data.to(param_cp.device) + else: + temp = param.data + return torch.equal(temp, param_cp.data) + + +@parameterize('init_device', [None, torch.device('cpu')]) +@parameterize('keep_gathered', [True, False]) +@parameterize('pin_memory', [True, False]) +def exam_chunk_basic(init_device, keep_gathered, pin_memory): + world_size = torch.distributed.get_world_size() + pg = ColoProcessGroup() + my_chunk = Chunk(chunk_size=1024, + process_group=pg, + dtype=torch.float32, + init_device=init_device, + cpu_shard_init=True, + keep_gathered=keep_gathered, + pin_memory=pin_memory) + + param_list = [] + param_cp_list = [] + + add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') + add_param(param_list, param_cp_list, 4, 4) + add_param(param_list, param_cp_list, 4, 8, 2, device='cuda') + add_param(param_list, param_cp_list, 1, 1, 5) + + for param in param_list: + my_chunk.append_tensor(param) + assert my_chunk.utilized_size == 597 + for param, param_cp in zip(param_list, param_cp_list): + check_euqal(param, param_cp) + my_chunk.close_chunk() + + if keep_gathered is False: + assert my_chunk.cpu_shard.size(0) == 1024 // world_size + assert my_chunk.device_type == 'cpu' + assert my_chunk.can_move + my_chunk.shard_move(get_current_device()) + else: + assert my_chunk.cuda_global_chunk.size(0) == 1024 + assert my_chunk.device_type == 'cuda' + assert not my_chunk.can_move + + assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size + flag = my_chunk.has_inf_or_nan + assert not flag, "has_inf_or_nan is {}".format(flag) + + my_chunk.access_chunk() + assert my_chunk.device_type == 'cuda' + for param, param_cp in zip(param_list, param_cp_list): + check_euqal(param, param_cp) + + assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 + my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE) + assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 3 + assert my_chunk.tensor_state_cnter[TensorState.COMPUTE] == 1 + assert not my_chunk.can_release + + for param in param_list: + my_chunk.tensor_trans_state(param, TensorState.COMPUTE) + my_chunk.tensor_trans_state(param, TensorState.HOLD_AFTER_BWD) + my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE) + + assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4 + assert my_chunk.can_reduce + my_chunk.reduce() + assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 + + if keep_gathered is False: + assert my_chunk.cuda_shard.size(0) == 1024 // world_size + assert my_chunk.device_type == 'cuda' + assert my_chunk.can_move + else: + assert my_chunk.cuda_global_chunk.size(0) == 1024 + assert my_chunk.device_type == 'cuda' + assert not my_chunk.can_move + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_chunk_basic() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@rerun_if_address_is_in_use() +def test_chunk_function(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_function(4) diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py index eb433f2c3..af98878e9 100644 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -4,19 +4,23 @@ import pytest import torch import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal +from tests.test_tensor.common_utils import set_seed def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): @@ -27,38 +31,41 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): chunk_manager.access_chunk(chunk) for (p0, p1) in zip(model.parameters(), torch_model.parameters()): - assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item()) - - -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): - optimizer.zero_grad() - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - optimizer.backward(loss) - return logits + assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) +@parameterize('init_device', [get_current_device()]) @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -def exam_gpt_fwd_bwd(placement_policy): - set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') +@parameterize('keep_gather', [False, True]) +@parameterize('model_name', ['gpt2', 'bert', 'albert']) +@parameterize('use_grad_checkpoint', [False, True]) +def exam_gpt_fwd_bwd(placement_policy, + keep_gather, + model_name: str, + use_grad_checkpoint: bool = False, + init_device=get_current_device()): + + get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - with ColoInitContext(device=get_current_device()): - model = model_builder() + set_seed(42) + with ColoInitContext(device=init_device): + model = model_builder(use_grad_checkpoint) - torch_model = model_builder().cuda() + set_seed(42) + torch_model = model_builder(use_grad_checkpoint).cuda() for torch_p, p in zip(torch_model.parameters(), model.parameters()): torch_p.data.copy_(p.data) world_size = torch.distributed.get_world_size() config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False + config_dict[world_size]['keep_gathered'] = keep_gather chunk_manager = ChunkManager(config_dict) gemini_manager = GeminiManager(placement_policy, chunk_manager) model = ZeroDDP(model, gemini_manager, pin_memory=True) + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) pg = ProcessGroup() amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) @@ -66,22 +73,24 @@ def exam_gpt_fwd_bwd(placement_policy): torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) - model.eval() - torch_model.eval() - set_seed(pg.dp_local_rank()) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): + # you can only test a single fwd + bwd. + # after bwd param is grad for Gemini, due to the chunk reuse optimization. if i > 0: break + input_ids, label = input_ids.cuda(), label.cuda() - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - model.backward(loss) + torch_optim.zero_grad() + zero_optim.zero_grad() - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) - assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format( - torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits) + # set random seed is same as torch_model.eval() + set_seed(42) + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + set_seed(42) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + + assert torch.equal(torch_loss, loss) check_grad(model, torch_model) @@ -101,4 +110,4 @@ def test_gpt(world_size): if __name__ == '__main__': - test_gpt(1) + test_gpt(4) diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py new file mode 100644 index 000000000..7fce84a50 --- /dev/null +++ b/tests/test_gemini/update/test_gemini_use_rmt.py @@ -0,0 +1,108 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.parallel import GeminiDDP, ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed + +# run gemini use the runtime memory tracer + + +@parameterize('placement_policy', ['auto']) +@parameterize('keep_gather', [False]) +@parameterize('model_name', ['repeated_computed_layers', 'bert', 'albert', 'gpt2']) +@parameterize('use_grad_checkpoint', [False, True]) +def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device='cpu'): + model = model_builder(use_grad_checkpoint) + + print(f'model_name {model_name}') + runtime_mem_tracer = RuntimeMemTracer(model) + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 0: + break + input_ids, label = input_ids.cuda(), label.cuda() + + # mem tracing + if i == 0: + run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) + memstats = runtime_mem_tracer.memstats() + runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list + print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data)) + print('runtime tracer: ', runtime_tracer_non_model_data) + print([memstats.param_used_step(p) for p in model.parameters()]) + + if model_name == 'repeated_computed_layers': + for idx, p in enumerate(model.parameters()): + step_list = memstats.param_used_step(p) + if idx < 4: + assert len(step_list) == 4 + + if model_name == 'repeated_computed_layers': + for idx, p in enumerate(model.parameters()): + step_list = memstats.param_used_step(p) + if idx < 4: + assert len(step_list) == 4 + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gather + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + pg = ProcessGroup() + set_seed(pg.dp_local_rank()) + for i, (input_ids, label) in enumerate(train_dataloader): + # you can only test a single fwd + bwd. + # after bwd param is grad for Gemini, due to the chunk reuse optimization. + # print(f'iteration {i}') + if i > 4: + break + input_ids, label = input_ids.cuda(), label.cuda() + + set_seed(42) + loss = run_fwd_bwd(model, input_ids, label, criterion, model) + + gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + + # print('gemini non model data:', gemini_non_model_data) + + assert len(gemini_non_model_data) == len(runtime_tracer_non_model_data), \ + f'model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}' + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gemini_use_rmt() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gemini_use_rmt(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gemini_use_rmt(1) diff --git a/tests/test_gemini/update/test_get_torch_model.py b/tests/test_gemini/update/test_get_torch_model.py new file mode 100644 index 000000000..e6d586b37 --- /dev/null +++ b/tests/test_gemini/update/test_get_torch_model.py @@ -0,0 +1,59 @@ +import os +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.nn.parallel import GeminiDDP +from colossalai.nn.parallel.utils import get_static_torch_model +from colossalai.tensor import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2']) +def run_convert_torch_module(model_name: str): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, _, _, _, _ = get_components_func() + + with ColoInitContext(device=torch.device("cpu")): + model = model_builder(checkpoint=False) + model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True) + pytorch_model = get_static_torch_model(model, only_rank_0=False) + + for n, p in pytorch_model.named_parameters(): + assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}" + + # get the static model should not change the original model + for n, p in model.named_parameters(): + assert isinstance(p, ColoParameter) + + for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()): + assert pn == cn + assert id(pm) != id(cm) + for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)): + assert id(pp) != id(cp) + assert pp.shape == cp.shape + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_convert_torch_module() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_convert_torch_module(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_convert_torch_module(2) diff --git a/tests/test_gemini/update/test_grad_clip.py b/tests/test_gemini/update/test_grad_clip.py new file mode 100644 index 000000000..fda1cf8cf --- /dev/null +++ b/tests/test_gemini/update/test_grad_clip.py @@ -0,0 +1,115 @@ +from functools import partial +from time import time + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', ['gpt2']) +def exam_grad_clipping(placement_policy, model_name: str): + set_seed(1912) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): + model = model_builder() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + p.data.copy_(torch_p.data) + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0) + + model.train() + torch_model.train() + + set_seed(dist.get_rank() * 3 + 128) + for i, (data, label) in enumerate(train_dataloader): + if i > 2: + break + data = data.cuda() + label = label.cuda() + + zero_optim.zero_grad() + torch_optim.zero_grad() + + torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim) + loss = run_fwd_bwd(model, data, label, criterion, zero_optim) + assert_close(torch_loss, loss) + + import apex.amp as apex_amp + torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0) + torch_optim.step() + zero_optim.step() + + check_param(model, torch_model) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_grad_clipping() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_grad_clip(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_grad_clip(2) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index 62822f133..07e6e65f2 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -1,25 +1,32 @@ from functools import partial -from time import time import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ColoParameter, ColoTensor from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer +from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal +from tests.test_tensor.common_utils import debug_print, set_seed + +# this model is large enough to slice to chunks +TEST_MODELS = ['gpt2'] +# these models are too small, all parameters in these models are compacted into one chunk +EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers'] def check_param(model: ZeroDDP, torch_model: torch.nn.Module): @@ -29,35 +36,31 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): for key, value in torch_dict.items(): # key is 'module.model.PARAMETER', so we truncate it key = key[7:] - if key == 'model.lm_head.weight': - continue assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key) - - -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): - optimizer.zero_grad() - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - optimizer.backward(loss) - return logits + assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -def exam_gpt_fwd_bwd(placement_policy): +@parameterize('model_name', TEST_MODELS) +def exam_model_step(placement_policy, model_name: str): set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - with ColoInitContext(device=get_current_device()): + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): model = model_builder() - torch_model = model_builder().cuda() for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) + p.data.copy_(torch_p.data) world_size = torch.distributed.get_world_size() config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) @@ -72,25 +75,72 @@ def exam_gpt_fwd_bwd(placement_policy): model = ZeroDDP(model, gemini_manager, pin_memory=True) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) - torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) model.eval() torch_model.eval() set_seed(dist.get_rank() * 3 + 128) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): + if i > 2: + break + input_ids, label = input_ids.cuda(), label.cuda() + zero_optim.zero_grad() + torch_optim.zero_grad() + + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + assert_close(torch_loss, loss) + + zero_optim.step() + torch_optim.step() + + check_param(model, torch_model) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +@parameterize('model_name', EXAMPLE_MODELS) +def exam_tiny_example(placement_policy, model_name: str): + set_seed(2008) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + init_dev = get_current_device() + with ColoInitContext(device=init_dev): + model = model_builder() + + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + p.data.copy_(torch_p.data) + + chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) + + model.eval() + torch_model.eval() + + set_seed(dist.get_rank() * 3 + 128) + for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break - zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask) - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) - assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) - # debug_print([0], zero_logits, torch_logits) + input_ids = input_ids.cuda() + label = label.cuda() + + zero_optim.zero_grad() + torch_optim.zero_grad() + + torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5) # atol should be 2e-5 for torch lower than 1.12 zero_optim.step() torch_optim.step() @@ -101,16 +151,17 @@ def exam_gpt_fwd_bwd(placement_policy): def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_gpt_fwd_bwd() + exam_model_step() + exam_tiny_example() @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() -def test_gpt(world_size): +def test_optim(world_size): run_func = partial(run_dist, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__': - test_gpt(1) + test_optim(1) diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py index ea2783fb8..b902bb0f0 100644 --- a/tests/test_gemini/update/test_zeroddp_state_dict.py +++ b/tests/test_gemini/update/test_zeroddp_state_dict.py @@ -19,9 +19,10 @@ from tests.test_tensor.common_utils import debug_print, set_seed @parameterize('placement_policy', ['cuda', 'cpu', 'auto']) @parameterize('keep_gathered', [True, False]) -def exam_state_dict(placement_policy, keep_gathered): +@parameterize('model_name', ['gpt2', 'bert']) +def exam_state_dict(placement_policy, keep_gathered, model_name: str): set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() with ColoInitContext(device=get_current_device()): @@ -44,8 +45,6 @@ def exam_state_dict(placement_policy, keep_gathered): torch_dict = torch_model.state_dict() for key, value in torch_dict.items(): - if key == 'model.lm_head.weight': - continue assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) @@ -53,9 +52,10 @@ def exam_state_dict(placement_policy, keep_gathered): @parameterize('placement_policy', ['cuda', 'cpu', 'auto']) @parameterize('keep_gathered', [True, False]) -def exam_load_state_dict(placement_policy, keep_gathered): +@parameterize('model_name', ['gpt2', 'bert']) +def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() with ColoInitContext(device=get_current_device()): @@ -82,8 +82,6 @@ def exam_load_state_dict(placement_policy, keep_gathered): zero_dict = model.state_dict(only_rank_0=False) for key, value in torch_dict.items(): - if key == 'model.lm_head.weight': - continue assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py index 74761668a..7f53415bf 100644 --- a/tests/test_gemini/update/test_zerooptim_state_dict.py +++ b/tests/test_gemini/update/test_zerooptim_state_dict.py @@ -9,12 +9,12 @@ import colossalai from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.nn.parallel import ZeroDDP from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed @@ -50,11 +50,11 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered): set_seed(dist.get_rank() * 3 + 128) model.train() - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): if i > 0: break optim.zero_grad() - logits = model(input_ids, attn_mask) + logits = model(input_ids) logits = logits.float() loss = criterion(logits, input_ids) optim.backward(loss) diff --git a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py index 5e1681da9..668b8a334 100644 --- a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -1,496 +1,552 @@ -import torch -import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn import (Classifier1D, Embedding1D, Linear1D_Col, Linear1D_Row, VanillaClassifier, - VocabParallelClassifier1D, VocabParallelCrossEntropyLoss1D, VocabParallelEmbedding1D) -from colossalai.utils import get_current_device, print_rank_0 -from torch.nn import Parameter - -from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal - - -def check_linear_col(): - device = get_current_device() - dtype = torch.float32 - INPUT_SIZE = HIDDEN_SIZE - OUTPUT_SIZE = 2 * HIDDEN_SIZE - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - dist.broadcast(A_master, src=0) - A = A_master.clone() - A.requires_grad = True - - W_shape = (OUTPUT_SIZE, INPUT_SIZE) - W_master = torch.randn(W_shape, dtype=dtype, device=device) - dist.broadcast(W_master, src=0) - W = torch.chunk(W_master, DEPTH, dim=0)[i] - W = W.clone() - W.requires_grad = True - - B_shape = (OUTPUT_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=device) - dist.broadcast(B_master, src=0) - B = torch.chunk(B_master, DEPTH, dim=0)[i] - B = B.clone() - B.requires_grad = True - - layer.weight = Parameter(W) - layer.bias = Parameter(B) - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - W_master = W_master.clone() - W_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master - C = torch.chunk(C_master, DEPTH, dim=-1)[i] - - check_equal(out, C) - print_rank_0('linear_col forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) - dist.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] - grad = grad.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - check_equal(A_grad, A.grad) - - W_grad = W_master.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = B_master.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - check_equal(B_grad, layer.bias.grad) - - print_rank_0('linear_col backward: pass') - - -def check_linear_row(): - device = get_current_device() - dtype = torch.float32 - INPUT_SIZE = HIDDEN_SIZE - OUTPUT_SIZE = 2 * HIDDEN_SIZE - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - dist.broadcast(A_master, src=0) - A = torch.chunk(A_master, DEPTH, dim=-1)[i] - A = A.clone() - A.requires_grad = True - - W_shape = (INPUT_SIZE, OUTPUT_SIZE) - W_master = torch.randn(W_shape, dtype=dtype, device=device) - dist.broadcast(W_master, src=0) - W = torch.chunk(W_master, DEPTH, dim=-1)[i] - W = W.clone() - W.requires_grad = True - - B_shape = (INPUT_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=device) - dist.broadcast(B_master, src=0) - B = B_master.clone() - B.requires_grad = True - - layer.weight = Parameter(W) - layer.bias = Parameter(B) - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - W_master = W_master.clone() - W_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master - C = C_master.clone() - - check_equal(out, C) - print_rank_0('linear_row forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) - dist.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] - check_equal(A_grad, A.grad) - - W_grad = W_master.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = B_master.grad - check_equal(B_grad, layer.bias.grad) - - print_rank_0('linear_row backward: pass') - - -def check_embed(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) - embed = embed.to(dtype).to(device) - embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - embed_master = embed_master.to(dtype).to(device) - - weight_master = embed_master.weight.data - torch.distributed.broadcast(weight_master, src=0) - weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] - embed.weight.data.copy_(weight) - - A_shape = (BATCH_SIZE, SEQ_LENGTH) - A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) - torch.distributed.broadcast(A_master, src=0) - A = A_master.clone() - out = embed(A) - - A_master = A_master.clone() - C_master = embed_master(A_master) - C = C_master.clone() - check_equal(out, C) - print_rank_0('embed forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - torch.distributed.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - grad_master = grad_master.clone() - C_master.backward(grad_master) - - B_grad = embed_master.weight.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] - check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') - - -def check_vocab_parallel_embed(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) - embed = embed.to(dtype).to(device) - embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - embed_master = embed_master.to(dtype).to(device) - - weight_master = embed_master.weight.data - torch.distributed.broadcast(weight_master, src=0) - weight = torch.chunk(weight_master, DEPTH, dim=0)[i] - embed.weight.data.copy_(weight) - - A_shape = (BATCH_SIZE, SEQ_LENGTH) - A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) - torch.distributed.broadcast(A_master, src=0) - A = A_master.clone() - out = embed(A) - - A_master = A_master.clone() - C_master = embed_master(A_master) - C = C_master.clone() - check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - torch.distributed.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - grad_master = grad_master.clone() - C_master.backward(grad_master) - - B_grad = embed_master.weight.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') - - -def check_classifier_no_given_weight(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - env.parallel_input_1d = False - parallel_input_1d = env.parallel_input_1d - layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True) - layer.to(dtype).to(device) - - layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True) - layer_master = layer_master.to(dtype).to(device) - - W_master = layer_master.weight.data - dist.broadcast(W_master, src=0) - W = torch.chunk(W_master, DEPTH, dim=-1)[i] - layer.weight.data.copy_(W) - B_master = layer_master.bias.data - dist.broadcast(B_master, src=0) - B = B_master.clone() - layer.bias.data.copy_(B) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - dist.broadcast(A_master, src=0) - if parallel_input_1d: - A = torch.chunk(A_master, DEPTH, dim=-1)[i] - A = A.clone() - else: - A = A_master.clone() - A.requires_grad = True - - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - C_master = layer_master(A_master) - C = C_master.clone() - - check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - dist.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - if parallel_input_1d: - A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] - check_equal(A_grad, A.grad) - - W_grad = layer_master.weight.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = layer_master.bias.grad - check_equal(B_grad, layer.bias.grad) - - print_rank_0('classifier (no given weight) backward: pass') - - -def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) - layer.to(dtype).to(device) - - layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) - layer_master = layer_master.to(dtype).to(device) - - W_master = layer_master.weight.data - dist.broadcast(W_master, src=0) - W = torch.chunk(W_master, DEPTH, dim=0)[i] - layer.weight.data.copy_(W) - B_master = layer_master.bias.data - dist.broadcast(B_master, src=0) - B = torch.chunk(B_master, DEPTH, dim=0)[i] - layer.bias.data.copy_(B) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - dist.broadcast(A_master, src=0) - A = A_master.clone() - A.requires_grad = True - - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - C_master = layer_master(A_master) - C = torch.chunk(C_master, DEPTH, dim=-1)[i] - - check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - dist.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] - grad = grad.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - check_equal(A_grad, A.grad) - - W_grad = layer_master.weight.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = layer_master.bias.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - check_equal(B_grad, layer.bias.grad) - - print_rank_0('vocab parallel classifier (no given weight) backward: pass') - - -def check_classifier_given_embed_weight(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) - embed = embed.to(dtype).to(device) - embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - embed_master = embed_master.to(dtype).to(device) - - weight_master = embed_master.weight.data - torch.distributed.broadcast(weight_master, src=0) - weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] - embed.weight.data.copy_(weight) - - env.parallel_input_1d = False - layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) - layer.to(dtype).to(device) - - layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) - layer_master = layer_master.to(dtype).to(device) - - A_shape = (BATCH_SIZE, SEQ_LENGTH) - A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) - torch.distributed.broadcast(A_master, src=0) - A = A_master.clone() - out = layer(embed(A)) - - A_master = A_master.clone() - C_master = layer_master(embed_master(A_master)) - C = C_master.clone() - check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - dist.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - - W_grad = embed_master.weight.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] - check_equal(W_grad, embed.weight.grad) - - print_rank_0('classifier (given embed weight) backward: pass') - - -def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) - embed = embed.to(dtype).to(device) - embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - embed_master = embed_master.to(dtype).to(device) - - weight_master = embed_master.weight.data - torch.distributed.broadcast(weight_master, src=0) - weight = torch.chunk(weight_master, DEPTH, dim=0)[i] - embed.weight.data.copy_(weight) - - env.parallel_input_1d = False - layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) - layer.to(dtype).to(device) - - layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) - layer_master = layer_master.to(dtype).to(device) - - A_shape = (BATCH_SIZE, SEQ_LENGTH) - A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) - torch.distributed.broadcast(A_master, src=0) - A = A_master.clone() - out = layer(embed(A)) - - A_master = A_master.clone() - C_master = layer_master(embed_master(A_master)) - C = torch.chunk(C_master, DEPTH, dim=-1)[i] - check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - dist.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] - grad = grad.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - - W_grad = embed_master.weight.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] - check_equal(W_grad, embed.weight.grad) - - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') - - -def check_vocab_parallel_loss(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - criterion = VocabParallelCrossEntropyLoss1D() - criterion_master = torch.nn.CrossEntropyLoss() - - out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES) - out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device) - torch.distributed.broadcast(out_master, src=0) - torch.distributed.broadcast(target_master, src=0) - out = torch.chunk(out_master, DEPTH, dim=-1)[i] - out = out.clone() - out.requires_grad = True - - loss = criterion(out, target_master) - - out_master = out_master.clone() - out_master.requires_grad = True - loss_master = criterion_master(out_master, target_master) - check_equal(loss, loss_master) - print_rank_0('vocab parallel loss forward: pass') - - loss.backward() - loss_master.backward() - - out_grad = out_master.grad - out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] - check_equal(out_grad, out.grad) - print_rank_0('vocab parallel loss backward: pass') +import torch +import torch.distributed as dist +from torch.nn import Parameter + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import ( + Classifier1D, + Embedding1D, + Linear1D_Col, + Linear1D_Row, + VanillaClassifier, + VocabParallelClassifier1D, + VocabParallelCrossEntropyLoss1D, + VocabParallelEmbedding1D, +) +from colossalai.utils import get_current_device, print_rank_0 + +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal + + +def check_linear_col(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + W = W.clone() + W.requires_grad = True + + B_shape = (OUTPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = B.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + + check_equal(out, C) + print_rank_0('linear_col forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear_col backward: pass') + + +def check_linear_row(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + A.requires_grad = True + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + W = W.clone() + W.requires_grad = True + + B_shape = (INPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = B_master.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = C_master.clone() + + check_equal(out, C) + print_rank_0('linear_row forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear_row backward: pass') + + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + env.parallel_input_1d = False + parallel_input_1d = env.parallel_input_1d + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = B_master.clone() + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + if parallel_input_1d: + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + else: + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = C_master.clone() + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + if parallel_input_1d: + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.bias.grad) + + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = C_master.clone() + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + criterion = VocabParallelCrossEntropyLoss1D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=-1)[i] + out = out.clone() + out.requires_grad = True + + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel loss backward: pass') + + +@torch.no_grad() +def check_linear_row_stream_inference(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + stream_chunk_num = 4 + assert HIDDEN_SIZE % stream_chunk_num == 0 + layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + W = W.clone() + + B_shape = (INPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = B_master.clone() + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + layer.chunk_weight() + layer.eval() + + out = layer(A) + + A_master = A_master.clone() + W_master = W_master.clone() + B_master = B_master.clone() + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = C_master.clone() + + check_equal(out, C) + print_rank_0('linear_row forward: pass') diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py index cbdcb1b72..897590f0d 100644 --- a/tests/test_layers/test_1d/test_1d.py +++ b/tests/test_layers/test_1d/test_1d.py @@ -1,46 +1,49 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial - -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers -from colossalai.initialize import launch -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from checks_1d.check_layer_1d import * - -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) - - -def check_layer(rank, world_size, port): - disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - check_linear_col() - check_linear_row() - check_embed() - check_vocab_parallel_embed() - check_classifier_no_given_weight() - check_vocab_parallel_classifier_no_given_weight() - check_classifier_given_embed_weight() - check_vocab_parallel_classifier_given_embed_weight() - check_vocab_parallel_loss() - - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_1d(): - world_size = 4 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_1d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from checks_1d.check_layer_1d import * + +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + check_linear_col() + check_linear_row() + check_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_vocab_parallel_loss() + + check_linear_row_stream_inference() + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_1d(): + world_size = 4 + run_func = partial(check_layer, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_1d() diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py index d398c4365..e946a1f59 100644 --- a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -4,12 +4,23 @@ import time import torch + from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.core import global_context from colossalai.logging import get_dist_logger -from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, - VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D, - VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D) +from colossalai.nn import ( + Classifier3D, + CrossEntropyLoss3D, + Embedding3D, + LayerNorm3D, + Linear3D, + PatchEmbedding3D, + VanillaClassifier, + VanillaPatchEmbedding, + VocabParallelClassifier3D, + VocabParallelCrossEntropyLoss3D, + VocabParallelEmbedding3D, +) from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.utils import get_current_device, print_rank_0 @@ -20,7 +31,6 @@ def check_linear(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() - dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -32,16 +42,16 @@ def check_linear(): i = global_context.get_local_rank(weight_parallel_mode) k = global_context.get_local_rank(output_parallel_mode) - layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True) + layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True) layer = layer.to(device) layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE) layer_master = layer_master.to(device) - weight_master = layer_master.weight.data.transpose(0, 1) + weight_master = layer_master.weight.data.transpose(0, 1).contiguous() torch.distributed.broadcast(weight_master, src=0) weight = torch.chunk(weight_master, DEPTH, dim=0)[k] weight = torch.chunk(weight, DEPTH, dim=-1)[j] - weight = torch.chunk(weight, DEPTH, dim=-1)[i] + weight = torch.chunk(weight, DEPTH, dim=0)[i] layer.weight.data.copy_(weight) bias_master = layer_master.bias.data torch.distributed.broadcast(bias_master, src=0) @@ -49,7 +59,7 @@ def check_linear(): layer.bias.data.copy_(bias) A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) + A_master = torch.randn(A_shape, device=device) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[k] @@ -72,7 +82,7 @@ def check_linear(): logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -94,7 +104,7 @@ def check_linear(): B_grad = layer_master.weight.grad.transpose(0, 1) B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] - B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) bias_grad = layer_master.bias.grad @@ -108,7 +118,6 @@ def check_layernorm(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() - dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -119,7 +128,7 @@ def check_layernorm(): i = global_context.get_local_rank(weight_parallel_mode) k = global_context.get_local_rank(output_parallel_mode) - norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype) + norm = LayerNorm3D(INPUT_SIZE, eps=1e-6) norm = norm.to(device) norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6) norm_master = norm_master.to(device) @@ -134,7 +143,7 @@ def check_layernorm(): norm.bias.data.copy_(bias) A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) + A_master = torch.randn(A_shape, device=device) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[k] @@ -159,7 +168,7 @@ def check_layernorm(): logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + grad_master = torch.randn(grad_shape, device=device) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[k] @@ -193,7 +202,6 @@ def check_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() - dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -204,10 +212,10 @@ def check_classifier_no_given_weight(): i = global_context.get_local_rank(weight_parallel_mode) k = global_context.get_local_rank(output_parallel_mode) - layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True) + layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True) layer = layer.to(device) - layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True, dtype=dtype) + layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True) layer_master = layer_master.to(device) weight_master = layer_master.weight.data @@ -219,7 +227,7 @@ def check_classifier_no_given_weight(): layer.bias.data.copy_(bias_master) A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) + A_master = torch.randn(A_shape, device=device) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[k] @@ -242,7 +250,7 @@ def check_classifier_no_given_weight(): logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=0)[j] @@ -283,7 +291,6 @@ def check_vocab_parallel_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() - dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -295,10 +302,10 @@ def check_vocab_parallel_classifier_no_given_weight(): k = global_context.get_local_rank(output_parallel_mode) layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True) - layer = layer.to(dtype).to(device) + layer = layer.to(device) layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True) - layer_master = layer_master.to(dtype).to(device) + layer_master = layer_master.to(device) weight_master = layer_master.weight.data torch.distributed.broadcast(weight_master, src=0) @@ -312,7 +319,7 @@ def check_vocab_parallel_classifier_no_given_weight(): layer.bias.data.copy_(bias) A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) + A_master = torch.randn(A_shape, device=device) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[k] @@ -336,7 +343,7 @@ def check_vocab_parallel_classifier_no_given_weight(): logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + grad_master = torch.randn(grad_shape, device=device) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -455,7 +462,6 @@ def check_vocab_parallel_classifier_given_embed_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() - dtype = torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -466,10 +472,10 @@ def check_vocab_parallel_classifier_given_embed_weight(): k = global_context.get_local_rank(output_parallel_mode) embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE) - embed = embed.to(dtype).to(device) + embed = embed.to(device) embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - embed_master = embed_master.to(dtype).to(device) + embed_master = embed_master.to(device) weight_master = embed_master.weight.data torch.distributed.broadcast(weight_master, src=0) @@ -479,10 +485,10 @@ def check_vocab_parallel_classifier_given_embed_weight(): embed.weight.data.copy_(weight) layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) - layer = layer.to(dtype).to(device) + layer = layer.to(device) layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) - layer_master = layer_master.to(dtype).to(device) + layer_master = layer_master.to(device) A_shape = (BATCH_SIZE, SEQ_LENGTH) A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) @@ -504,7 +510,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + grad_master = torch.randn(grad_shape, device=device) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -546,12 +552,12 @@ def check_patch_embed(): i = global_context.get_local_rank(weight_parallel_mode) k = global_context.get_local_rank(output_parallel_mode) - layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE) torch.nn.init.ones_(layer.cls_token) torch.nn.init.ones_(layer.pos_embed) layer = layer.to(device) - layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE) torch.nn.init.ones_(layer_master.cls_token) torch.nn.init.ones_(layer_master.pos_embed) layer_master = layer_master.to(device) @@ -566,7 +572,7 @@ def check_patch_embed(): layer.bias.data.copy_(proj_bias) A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) + A_master = torch.randn(A_shape, device=device) torch.distributed.broadcast(A_master, src=0) A = A_master.clone() @@ -586,7 +592,7 @@ def check_patch_embed(): logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + grad_master = torch.randn(grad_shape, device=device) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[k] @@ -639,9 +645,9 @@ def check_embed(): k = global_context.get_local_rank(output_parallel_mode) layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE) - layer = layer.to(dtype).to(device) + layer = layer.to(device) layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - layer_master = layer_master.to(dtype).to(device) + layer_master = layer_master.to(device) weight_master = layer_master.weight.data torch.distributed.broadcast(weight_master, src=0) @@ -669,7 +675,7 @@ def check_embed(): logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + grad_master = torch.randn(grad_shape, device=device) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[k] @@ -686,10 +692,7 @@ def check_embed(): B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - if j == k: - logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) - else: - logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None)) + logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -709,9 +712,9 @@ def check_vocab_parallel_embed(): k = global_context.get_local_rank(output_parallel_mode) layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE) - layer = layer.to(dtype).to(device) + layer = layer.to(device) layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - layer_master = layer_master.to(dtype).to(device) + layer_master = layer_master.to(device) weight_master = layer_master.weight.data torch.distributed.broadcast(weight_master, src=0) @@ -741,7 +744,7 @@ def check_vocab_parallel_embed(): logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + grad_master = torch.randn(grad_shape, device=device) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[k] @@ -771,7 +774,6 @@ def check_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() - dtype = torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -783,7 +785,7 @@ def check_loss(): criterion_master = torch.nn.CrossEntropyLoss() out_shape = (BATCH_SIZE, NUM_CLASSES) - out_master = torch.randn(out_shape, dtype=dtype, device=device) + out_master = torch.randn(out_shape, device=device) target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) @@ -836,7 +838,7 @@ def check_vocab_parallel_loss(): criterion_master = torch.nn.CrossEntropyLoss() out_shape = (BATCH_SIZE, NUM_CLASSES) - out_master = torch.randn(out_shape, dtype=dtype, device=device) + out_master = torch.randn(out_shape, device=device) target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) diff --git a/tests/test_layers/test_3d/checks_3d/common.py b/tests/test_layers/test_3d/checks_3d/common.py index 32ab63711..afb19c474 100644 --- a/tests/test_layers/test_3d/checks_3d/common.py +++ b/tests/test_layers/test_3d/checks_3d/common.py @@ -12,8 +12,8 @@ NUM_BLOCKS = 2 IMG_SIZE = 16 VOCAB_SIZE = 16 + def check_equal(A, B): eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2) - assert eq - return eq - + assert eq, f"\nA = {A}\nB = {B}" + return eq \ No newline at end of file diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py index c79dde2a1..29a8b3aea 100644 --- a/tests/test_layers/test_3d/test_3d.py +++ b/tests/test_layers/test_3d/test_3d.py @@ -10,9 +10,8 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight, - check_embed, check_layernorm, check_linear, check_loss, check_patch_embed, - check_vocab_parallel_classifier_given_embed_weight, +from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear, + check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight, check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, check_vocab_parallel_loss) @@ -30,7 +29,6 @@ def check_layer(): check_layernorm() check_classifier_no_given_weight() check_vocab_parallel_classifier_no_given_weight() - check_classifier_given_embed_weight() check_vocab_parallel_classifier_given_embed_weight() check_embed() check_patch_embed() diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index 37e8a4bab..d608ebf07 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -1,77 +1,75 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp - -from colossalai.nn import MoeLoss -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.sharded_model.utils import col_model_deepcopy -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.engine.gradient_handler import MoeGradientHandler -from colossalai.context import MOE_CONTEXT -from colossalai.testing import assert_equal_in_group - -from tests.test_zero.common import CONFIG, check_grads_padding, run_fwd_bwd -from tests.test_moe.test_moe_zero_init import MoeModel - - -@parameterize("enable_autocast", [False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_model_test(enable_autocast, shard_strategy_class): - shard_strategy = shard_strategy_class() - - get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module') - _, train_dataloader, _, optimizer_class, _ = get_components_func() - criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = MoeModel(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy) - - # check whether parameters are identical in ddp - for name, p in zero_model.named_parameters(): - if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: - assert_equal_in_group(p.colo_attr.data_payload) - - model = MoeModel(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda() - grad_handler = MoeGradientHandler(model) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - - data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() - run_fwd_bwd(model, data, label, criterion, enable_autocast) - run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) - grad_handler.handle_gradient() - - check_grads_padding(model, zero_model, loose=True) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) - run_model_test() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) -@rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_moe_zero_model(world_size=2) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.nn import MoeLoss +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.sharded_model.utils import col_model_deepcopy +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_zero.common import CONFIG, check_grads_padding, run_fwd_bwd + + +@parameterize("enable_autocast", [False]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_model_test(enable_autocast, shard_strategy_class): + shard_strategy = shard_strategy_class() + + get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + _, train_dataloader, _, optimizer_class, _ = get_components_func() + criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + + with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = MoeModel(checkpoint=True) + zero_model = ShardedModelV2(zero_model, shard_strategy) + + # check whether parameters are identical in ddp + for name, p in zero_model.named_parameters(): + if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: + assert_equal_in_group(p.colo_attr.data_payload) + + model = MoeModel(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda() + grad_handler = MoeGradientHandler(model) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + + data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() + run_fwd_bwd(model, data, label, criterion, enable_autocast) + run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) + grad_handler.handle_gradient() + + check_grads_padding(model, zero_model, loose=True) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + run_model_test() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_moe_zero_model(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index da67b7610..9d9a7bd17 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -1,126 +1,124 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.amp import convert_to_apex_amp -from colossalai.nn import MoeLoss -from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model.utils import col_model_deepcopy -from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.sharded_optim._utils import has_inf_or_nan -from colossalai.utils import get_current_device -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.engine.gradient_handler import MoeGradientHandler -from colossalai.context import MOE_CONTEXT -from colossalai.testing import assert_equal_in_group - -from tests.test_zero.common import CONFIG, check_sharded_model_params -from tests.test_moe.test_moe_zero_init import MoeModel - - -def _run_step(model, optimizer, data, label, criterion, grad_handler): - model.train() - optimizer.zero_grad() - - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - - loss = loss.float() - if isinstance(model, ShardedModelV2): - optimizer.backward(loss) - else: - loss.backward() - - if grad_handler is not None: - grad_handler.handle_gradient() - - optimizer.step() - - -@parameterize("cpu_offload", [True]) -@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug -@parameterize("reuse_fp16_shard", [True, False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def _run_test_sharded_optim_v2(cpu_offload, - shard_strategy_class, - use_cpuadam, - reuse_fp16_shard, - gpu_margin_mem_ratio=0.0): - shard_strategy = shard_strategy_class() - if use_cpuadam and cpu_offload is False: - return - MOE_CONTEXT.reset_loss() - get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module') - _, train_dataloader, _, optimizer_class, _ = get_components_func() - criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - - with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = MoeModel(checkpoint=True) - - zero_model = ShardedModelV2(zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=reuse_fp16_shard) - - # check whether parameters are identical in ddp - for name, p in zero_model.named_parameters(): - if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: - assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device())) - - model = MoeModel(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda().float() - - if use_cpuadam: - optimizer_class = CPUAdam - optim = optimizer_class(model.parameters(), lr=1e-3) - sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, - sharded_optim, - initial_scale=2**5, - gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) - apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) - apex_grad_handler = MoeGradientHandler(model) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - data, label = data.cuda(), label.cuda() - _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler) - _run_step(zero_model, sharded_optim, data, label, criterion, None) - check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) - for param in model.parameters(): - assert not has_inf_or_nan(param) - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) - _run_test_sharded_optim_v2() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) -@rerun_if_address_is_in_use() -def test_moe_zero_optim(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_moe_zero_optim(world_size=4) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.context import MOE_CONTEXT +from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.nn import MoeLoss +from colossalai.nn.optimizer import CPUAdam +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_zero.common import CONFIG, check_sharded_model_params + + +def _run_step(model, optimizer, data, label, criterion, grad_handler): + model.train() + optimizer.zero_grad() + + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + + loss = loss.float() + if isinstance(model, ShardedModelV2): + optimizer.backward(loss) + else: + loss.backward() + + if grad_handler is not None: + grad_handler.handle_gradient() + + optimizer.step() + + +@parameterize("cpu_offload", [True]) +@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug +@parameterize("reuse_fp16_shard", [True, False]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def _run_test_sharded_optim_v2(cpu_offload, + shard_strategy_class, + use_cpuadam, + reuse_fp16_shard, + gpu_margin_mem_ratio=0.0): + shard_strategy = shard_strategy_class() + if use_cpuadam and cpu_offload is False: + return + MOE_CONTEXT.reset_loss() + get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + _, train_dataloader, _, optimizer_class, _ = get_components_func() + criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + + with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = MoeModel(checkpoint=True) + + zero_model = ShardedModelV2(zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', + reuse_fp16_shard=reuse_fp16_shard) + + # check whether parameters are identical in ddp + for name, p in zero_model.named_parameters(): + if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: + assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device())) + + model = MoeModel(checkpoint=True).half() + col_model_deepcopy(zero_model, model) + model = model.cuda().float() + + if use_cpuadam: + optimizer_class = CPUAdam + optim = optimizer_class(model.parameters(), lr=1e-3) + sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, + sharded_optim, + initial_scale=2**5, + gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) + apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) + apex_grad_handler = MoeGradientHandler(model) + + for i, (data, label) in enumerate(train_dataloader): + if i > 5: + break + data, label = data.cuda(), label.cuda() + _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler) + _run_step(zero_model, sharded_optim, data, label, criterion, None) + check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) + for param in model.parameters(): + assert not has_inf_or_nan(param) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + _run_test_sharded_optim_v2() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_moe_zero_optim(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_zero_optim(world_size=4) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py index 64149b5a4..d317dc2e3 100644 --- a/tests/test_optimizer/test_cpu_adam.py +++ b/tests/test_optimizer/test_cpu_adam.py @@ -1,4 +1,5 @@ import math + import torch from colossalai.testing import parameterize @@ -65,11 +66,10 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype): exp_avg_sq = torch.rand(p_data.shape) exp_avg_sq_copy = exp_avg_sq.clone() - try: - import cpu_adam - cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) - except: - raise ImportError("Import cpu adam error, please install colossal from source code") + from colossalai.kernel.op_builder import CPUAdamBuilder + cpu_optim = CPUAdamBuilder().load() + + cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) cpu_adam_op.step( step, @@ -114,3 +114,7 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype): assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}") max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq)) assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}") + + +if __name__ == '__main__': + test_cpu_adam() diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py index 6e0aaf45f..7b9b6e9c4 100644 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -1,8 +1,8 @@ -from numpy import dtype +import math + import torch import torch.nn as nn - -import math +from numpy import dtype from colossalai.testing import parameterize from colossalai.utils import multi_tensor_applier @@ -46,12 +46,11 @@ def torch_adam_update( @parameterize('p_dtype', [torch.float, torch.half]) @parameterize('g_dtype', [torch.float, torch.half]) def test_adam(adamw, step, p_dtype, g_dtype): - try: - import colossal_C - fused_adam = colossal_C.multi_tensor_adam - dummy_overflow_buf = torch.cuda.IntTensor([0]) - except: - raise ImportError("No colossal_C kernel installed.") + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + fused_adam = fused_optim.multi_tensor_adam + + dummy_overflow_buf = torch.cuda.IntTensor([0]) count = 0 @@ -71,7 +70,7 @@ def test_adam(adamw, step, p_dtype, g_dtype): weight_decay = 0 multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw, - True, weight_decay) + True, weight_decay, -1) torch_adam_update( step, diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py index fe0333bde..7ce2cd433 100644 --- a/tests/test_pipeline/rpc_test_utils.py +++ b/tests/test_pipeline/rpc_test_utils.py @@ -20,6 +20,33 @@ def color_debug(text, prefix=' ', color='blue'): color = color.upper() print(getattr(Back, color), prefix, Style.RESET_ALL, text) +class MLP(nn.Module): + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x.sum() + +class DAG_MLP(nn.Module): + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + self.dag_layer = nn.Linear(dim, dim, bias=False) + + for _ in range(layers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, x, y): + for layer in self.layers: + x = layer(x) + y = self.dag_layer(y) + return x.sum(), y.sum() class RpcTestModel(nn.Module): diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py new file mode 100644 index 000000000..c4dc617b1 --- /dev/null +++ b/tests/test_pipeline/test_middleware_1f1b.py @@ -0,0 +1,128 @@ +import torch +import pytest +import os +import torch.multiprocessing as mp +import torch.distributed.rpc as rpc + +from torch import nn +from torch._C._distributed_rpc import _is_current_rpc_agent_set +from colossalai import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from colossalai.fx import ColoTracer +from colossalai.pipeline.middleware.adaptor import get_fx_topology +from rpc_test_utils import MLP, DAG_MLP +from functools import partial +from colossalai.testing import parameterize, rerun_if_address_is_in_use + +# global variable for model created +batch_size = 16 +dim = 10 +rpc_is_initialized = _is_current_rpc_agent_set + +def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): + model.eval() + tracer = ColoTracer() + meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) + annotated_model = balanced_split_pass(gm, stage_num) + top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True) + topo = get_fx_topology(top_module) + for submodule in split_submodules: + if isinstance(submodule, torch.fx.GraphModule): + setattr(submodule, '_topo', topo) + return split_submodules[pp_rank+1] + +def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) + return partition + +def run_master(model_cls, world_size, forward_only): + torch.manual_seed(100) + + epoch = 3 + device = 'cuda' + stage_num = world_size + chunk = 1 + num_microbatches = 8 + use_checkpoint = 'store_true' + + if model_cls == MLP: + def data_gen(): + x = torch.zeros((batch_size, dim)) + kwargs = dict(x=x) + return kwargs + model = model_cls(dim, stage_num * 3) + if forward_only: + labels = None + else: + labels = 1 + elif model_cls == DAG_MLP: + def data_gen(): + x = torch.zeros((batch_size, dim)) + y = torch.zeros((batch_size, dim)) + kwargs = dict(x=x, y=y) + return kwargs + model = model_cls(dim, stage_num * 3) + if forward_only: + labels = None + else: + labels = 1 + else: + pass + + data_kwargs = data_gen() + + engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint,) + if not forward_only: + engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) + + for _ in range(epoch): + input_x = torch.randn((batch_size, dim), device=device) + input_y = torch.randn((batch_size, dim), device=device) + logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) + +def run_worker(rank, model_cls, world_size, forward_only, master_func): + master_addr = 'localhost' + master_port = 29020 + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = str(master_port) + + disable_existing_loggers() + + launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=1, + tp_degree=1, + num_worker_threads=128, + device='cuda') + + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(model_cls, world_size, forward_only) + # barrier here + if rpc_is_initialized(): + rpc.shutdown() + +@pytest.mark.skip("skip due to CI torch version 1.11") +@parameterize('model_cls', [MLP, DAG_MLP]) +@parameterize('forward_only', [True, False]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp_middleware_fwd(model_cls, forward_only): + world_size = 4 + master_func = run_master + mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size) + +if __name__ == "__main__": + test_pp_middleware_fwd() \ No newline at end of file diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_tensor/common_utils/_utils.py index 5c5d06622..6b58aa801 100644 --- a/tests/test_tensor/common_utils/_utils.py +++ b/tests/test_tensor/common_utils/_utils.py @@ -1,11 +1,13 @@ import os import random + import numpy as np import torch import torch.distributed as dist -from colossalai.core import global_context as gpc + from colossalai.context import ParallelMode -from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern +from colossalai.core import global_context as gpc +from colossalai.tensor import ComputePattern, ComputeSpec, ShardSpec def set_seed(seed): @@ -15,6 +17,7 @@ def set_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False def check_equal(A, B): diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py index 6f2ef9fa8..ad8ac87b2 100644 --- a/tests/test_tensor/model/test_gpt2.py +++ b/tests/test_tensor/model/test_gpt2.py @@ -1,21 +1,26 @@ -import pytest - from functools import partial -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed +import pytest import torch -from torch.nn.parallel import DistributedDataParallel as DDP import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, debug_print +from tests.test_tensor.common_utils import ( + debug_print, + set_seed, + split_param_col_tp1d, + split_param_row_tp1d, + tensor_equal, + tensor_shard_equal, +) def init_1d_row_spec(model, pg: ProcessGroup): @@ -107,10 +112,10 @@ def run_gpt(init_spec_func, use_ddp): torch_model.eval() set_seed(pg.dp_local_rank()) torch.distributed.barrier() - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) - logits = model(colo_input, attn_mask) - torch_logits = torch_model(input_ids, attn_mask) + logits = model(colo_input) + torch_logits = torch_model(input_ids) assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}" loss = criterion(logits, input_ids) torch_loss = criterion(torch_logits, input_ids) diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py index c50393467..3f53b94e0 100644 --- a/tests/test_tensor/model/test_model.py +++ b/tests/test_tensor/model/test_model.py @@ -1,20 +1,25 @@ -import pytest from functools import partial + +import pytest import torch import torch.multiprocessing as mp -from colossalai.tensor.colo_parameter import ColoParameter import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ColoTensor, ProcessGroup from colossalai.nn.optimizer import ColossalaiOptimizer - +from colossalai.tensor import ColoTensor, ProcessGroup +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import tensor_shard_equal, check_equal, set_seed, \ - split_param_row_tp1d, split_param_col_tp1d +from tests.test_tensor.common_utils import ( + check_equal, + set_seed, + split_param_col_tp1d, + split_param_row_tp1d, + tensor_shard_equal, +) def run_1d_hybrid_tp(model_name): @@ -112,7 +117,7 @@ def run_1d_hybrid_tp(model_name): else: output_torch = model_torch(data, label) loss_torch = output_torch - assert torch.allclose(loss, loss_torch, rtol=1e-2) + assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed" torch.distributed.barrier() loss.backward() @@ -169,7 +174,7 @@ def test_colo_optimizer(): get_components_func = non_distributed_component_funcs.get_callable('simple_net') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() set_seed(1) - with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()): + with ColoInitContext(device=get_current_device()): model = model_builder(checkpoint=True) colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1)) @@ -266,7 +271,7 @@ def _run_pretrain_load(): from transformers import BertForMaskedLM set_seed(1) model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased') - with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()): + with ColoInitContext(device=get_current_device()): model = BertForMaskedLM.from_pretrained('bert-base-uncased') model_pretrained = model_pretrained.cuda() diff --git a/tests/test_tensor/model/test_module_spec.py b/tests/test_tensor/model/test_module_spec.py index a3eda1d8a..997b416f1 100644 --- a/tests/test_tensor/model/test_module_spec.py +++ b/tests/test_tensor/model/test_module_spec.py @@ -1,24 +1,28 @@ from copy import deepcopy -import pytest from functools import partial +import pytest import torch import torch.multiprocessing as mp -from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec -from colossalai.nn.parallel.layers import init_colo_module, check_colo_module -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed - import colossalai -from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext - -from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec - +from colossalai.nn.parallel.layers import check_colo_module, init_colo_module +from colossalai.tensor import ( + ColoTensor, + ColoTensorSpec, + ComputePattern, + ComputeSpec, + ProcessGroup, + ReplicaSpec, + ShardSpec, + distspec, +) from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port - +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal def run_model_with_spec(mode, model_name): @@ -134,7 +138,7 @@ def run_linear_with_spec(mode): def run_check_shared_param(): - from transformers import BertForMaskedLM, BertConfig + from transformers import BertConfig, BertForMaskedLM hidden_dim = 8 num_head = 4 sequence_length = 12 @@ -153,7 +157,7 @@ def run_check_shared_param(): num_hidden_layers=num_layer, hidden_dropout_prob=0., attention_probs_dropout_prob=0.) - with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()): + with ColoInitContext(device=get_current_device()): model = BertForMaskedLM(config) model = model.cuda() diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py index 8171ebfab..2f7aebed5 100644 --- a/tests/test_tensor/test_context.py +++ b/tests/test_tensor/test_context.py @@ -1,40 +1,69 @@ +from functools import partial + import pytest -from colossalai.utils.model.colo_init_context import ColoInitContext - import torch +import torch.multiprocessing as mp +import colossalai +from colossalai.tensor import ( + ColoParameter, + ColoTensorSpec, + ComputePattern, + ComputeSpec, + ProcessGroup, + ReplicaSpec, + ShardSpec, +) +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed -@pytest.mark.skip -# FIXME(ver217): support lazy init -def test_lazy_init(): - in_dim = 4 - out_dim = 5 +def run_colo_init_context(rank: int, world_size: int, port: int): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - with ColoInitContext(lazy_memory_allocate=True) as ctx: - fc = torch.nn.Linear(in_dim, out_dim, bias=True) + # make sure seed of each process is the same, so the params are consistent among processes and the params are exactly replicated. + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - # lazy_memory_allocate=True, no payload is maintained - assert fc.weight._torch_tensor.numel() == 0 + # keep parameters replicated during init + with ColoInitContext(device=get_current_device()): + model1 = model_builder() - fc.weight.torch_tensor() - assert fc.weight._torch_tensor.numel() == in_dim * out_dim + # shard the parameters during init + set_seed(42) + shard_spec = ReplicaSpec() + + # If using ShardSpec, the assertations will failed. + # But it is not a bug, the initialized values are not consist with the original one. + # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size]) + default_pg = ProcessGroup(tp_degree=world_size) + with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec): + model2 = model_builder() + + # reshard both models + new_shard = ShardSpec(dims=[-1], num_partitions=[world_size]) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + p1: ColoParameter = p1 + p1.set_process_group(ProcessGroup(tp_degree=world_size)) + p1.set_dist_spec(new_shard) + p2.set_dist_spec(new_shard) + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert (torch.allclose(p1, p2)) -@pytest.mark.skip -def test_device(): - in_dim = 4 - out_dim = 5 - - with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()) as ctx: - fc = torch.nn.Linear(in_dim, out_dim, bias=True) - - # eval an lazy parameter - fc.weight.torch_tensor() - assert fc.weight.device == get_current_device() +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_colo_init_context(world_size): + run_func = partial(run_colo_init_context, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__': - test_lazy_init() - test_device() + test_colo_init_context(2) diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py new file mode 100644 index 000000000..c1ab30601 --- /dev/null +++ b/tests/test_tensor/test_mix_gather.py @@ -0,0 +1,333 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.utils import mix_gather_simulator +from colossalai.utils import free_port + + +def check_mix_gather_S0S1(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [0]) + b_target_pair = (b, [1]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_slice = [4, 2] # (4, 2) + rank_slice = 4 + f_start = (rank // rank_slice) * tensor_slice[0] + b_start = (rank % rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [0], 1: [1]} + + # DistSpec: + # shard_sequence: S0,S1 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S0S1(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + dim_partition_dict = {0: [0], 1: [1]} + + tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) + rank_slice = 4 + f_start = (rank // rank_slice) * tensor_slice[0] + b_start = (rank % rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + # DistSpec: + # shard_sequence: S0,S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [1]} + # DistSpec: + # shard_sequence: R,S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_S1S0(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [1]) + b_target_pair = (b, [0]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_slice = [2, 4] + rank_slice = 4 + f_start = (rank % rank_slice) * tensor_slice[0] + b_start = (rank // rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [1], 1: [0]} + + # DistSpec: + # shard_sequence: S1,S0 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S1S0(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) + rank_slice = 4 + f_start = (rank % rank_slice) * tensor_slice[0] + b_start = (rank // rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [1], 1: [0]} + + # DistSpec: + # shard_sequence: S1,S0 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [0]} + # DistSpec: + # shard_sequence: R,S0 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_S01R(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [0, 1]) + b_target_pair = (b, []) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda() + + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S01R(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + rank_stride = tensor_width // 8 + tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda() + + dim_partition_dict = {0: [0, 1]} + + # DistSpec: + # shard_sequence: S01, R + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {0: [0]} + + # DistSpec: + # shard_sequence: S1, R + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_RS01(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + + (f, b) = (0, 1) + f_target_pair = (f, []) + b_target_pair = (b, [0, 1]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda() + + dim_partition_dict = {1: [0, 1]} + # DistSpec: + # shard_sequence: R, S01 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_RS01(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + rank_stride = tensor_width // 8 + tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda() + + dim_partition_dict = {1: [0, 1]} + + # DistSpec: + # shard_sequence: R, S01 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [0]} + + # DistSpec: + # shard_sequence: R, S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_comm(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + physical_mesh_id = torch.arange(0, 8) + assert rank == gpc.get_global_rank() + + mesh_shape = (2, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True, need_flatten=True) + + check_mix_gather_S0S1(device_mesh, rank) + + check_two_all_gather_S0S1(device_mesh, rank) + + check_mix_gather_S1S0(device_mesh, rank) + + check_two_all_gather_S1S0(device_mesh, rank) + + check_mix_gather_S01R(device_mesh, rank) + + check_two_all_gather_S01R(device_mesh, rank) + + check_mix_gather_RS01(device_mesh, rank) + + check_two_all_gather_RS01(device_mesh, rank) + + +@pytest.mark.skip(reason="Skip because the check functions assume 8 GPUS but CI only have 4 GPUs") +def test_mix_gather(): + world_size = 8 + run_func = partial(check_comm, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_mix_gather() diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index 7aedb0d5e..85008c67a 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -1,5 +1,4 @@ from functools import partial -from lib2to3 import pgen2 import pytest import torch diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index ad5a83e57..7e611e8a1 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -7,18 +7,16 @@ from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import ZeroDDP +from colossalai.gemini.chunk import search_chunk_configuration +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.parallel import GeminiDDP, ZeroDDP from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal +from tests.test_tensor.common_utils import set_seed, tensor_shard_equal from tests.test_tensor.model.test_gpt2 import init_megatron_spec @@ -29,8 +27,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup): for key, value in torch_dict.items(): # key is 'module.model.PARAMETER', so we truncate it key = key[7:] - if key == 'model.lm_head.weight': - continue assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) @@ -38,9 +34,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup): "parameter '{}' has problem.".format(key) -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): +def run_fwd_bwd(model, criterion, optimizer, input_ids): optimizer.zero_grad() - logits = model(input_ids, attn_mask) + logits = model(input_ids) logits = logits.float() loss = criterion(logits, input_ids) optimizer.backward(loss) @@ -96,31 +92,35 @@ def run_gpt(placement_policy, tp_init_spec_func=None): init_device = torch.device('cpu') else: init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) + model = GeminiDDP(model, init_device, placement_policy, True, False, 32) + # The same as the following 3 lines + # chunk_manager = ChunkManager(config_dict, init_device=init_device) + # gemini_manager = GeminiManager(placement_policy, chunk_manager) + # model = ZeroDDP(model, gemini_manager, pin_memory=True) + + zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1) + # The same as the following 2 lines + # optimizer = HybridAdam(model.parameters(), lr=1e-3) + # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) - print(chunk_manager) check_param(model, torch_model, pg) model.eval() torch_model.eval() set_seed(pg.dp_local_rank()) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) - zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask) - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids) assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) zero_optim.step() diff --git a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py new file mode 100644 index 000000000..6d89fb90c --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.utils import build_checkpoints +from torch.optim import Adam + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def test_global_model(): + model = DummyModel() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model) + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 0 + assert meta['dist_meta'] is None + orig_state_dict = model.state_dict() + global_state_dict = model_checkpoints[0] + assert set(orig_state_dict.keys()) == set(global_state_dict.keys()) + for k, v in orig_state_dict.items(): + assert torch.equal(v, global_state_dict[k]) + + +def test_global_model_shard(): + model = DummyModel() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model) + assert len(model_checkpoints) == 2 + assert len(optimizer_checkpoints) == 0 + assert meta['dist_meta'] is None + orig_state_dict = model.state_dict() + assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys()) + assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0 + for k, v in orig_state_dict.items(): + for state_dict in model_checkpoints: + if k in state_dict: + assert torch.equal(v, state_dict[k]) + + +def test_global_optimizer(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer) + assert len(optimizer_checkpoints) == 1 + assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1} + for state in meta['paired_os'].values(): + for k, is_paired in state.items(): + if k == 'step': + assert not is_paired + else: + assert is_paired + orig_state_dict = optimizer.state_dict() + state_dict = optimizer_checkpoints[0] + for k, orig_state in orig_state_dict['state'].items(): + state = state_dict['state'][k] + for v1, v2 in zip(orig_state.values(), state.values()): + if isinstance(v2, torch.Tensor): + assert torch.equal(v1, v2) + else: + assert v2 == v2 + assert orig_state_dict['param_groups'] == state_dict['param_groups'] + + +def test_global_optimizer_shard(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer) + assert len(optimizer_checkpoints) == 2 + assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1] + orig_state_dict = optimizer.state_dict() + assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set( + optimizer_checkpoints[1]['state'].keys()) + assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0 + for k, orig_state in orig_state_dict['state'].items(): + state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][ + 'state'] else optimizer_checkpoints[1]['state'][k] + for v1, v2 in zip(orig_state.values(), state.values()): + if isinstance(v2, torch.Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + + assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups'] + + +def test_dist_model_optimizer(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) + assert dist_meta == meta['dist_meta'] + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 1 + assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0] + assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state'] + dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) + assert dist_meta == meta['dist_meta'] + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 1 + + +if __name__ == '__main__': + test_global_model() + test_global_model_shard() + test_global_optimizer() + test_global_optimizer_shard() + test_dist_model_optimizer() diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py new file mode 100644 index 000000000..780c13dc5 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_load.py @@ -0,0 +1,188 @@ +from copy import deepcopy +from functools import partial +from tempfile import TemporaryDirectory +from typing import Dict + +import colossalai +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint_io.io import load, save +from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta) +from torch import Tensor +from torch.nn import Module +from torch.optim import Adam, Optimizer + + +def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: + assert set(a.keys()) == set(b.keys()) + for k, v in a.items(): + assert torch.equal(v, b[k]) + + +def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None: + assert set(a['state'].keys()) == set(b['state'].keys()) + for k, state in a['state'].items(): + b_state = b['state'][k] + for v1, v2 in zip(state.values(), b_state.values()): + if isinstance(v1, Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + if not ignore_param_gruops: + assert a['param_groups'] == b['param_groups'] + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0): + with torch.no_grad(): + for p in model.parameters(): + p.fill_(scalar) + for state in optimizer.state.values(): + for v in state.values(): + if isinstance(v, Tensor): + v.fill_(scalar) + + +def get_dist_metas(nprocs: int, zero: bool = False): + dp_world_size = nprocs // 2 + dist_metas = [] + for rank in range(nprocs): + if zero: + dist_metas.append({ + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + }) + else: + dist_metas.append({ + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + }) + return dist_metas + + +def get_redist_meta(nprocs: int): + dp_world_size = nprocs // 2 + rank_meta = { + 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, + 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} + } + param_meta = { + 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamRedistMeta(dp_world_size, 1) + } + return RedistMeta(rank_meta, [], param_meta) + + +@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0]) +def test_save_global_load_global(max_shard_size_gb: float): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb) + new_model, new_optimizer = prepare_model_optim() + load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb) + check_model_state_dict(model.state_dict(), new_model.state_dict()) + check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict()) + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + func() + + +def launch_dist(fn, world_size: int): + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + + +def save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + reset_model_optim(model, optmizer) + world_size = dist.get_world_size() + rank = dist.get_rank() + save(dir_name, model, optmizer, dist_meta=get_dist_metas(world_size, zero)[rank]) + + +def load_and_check_dist(dir_name: str): + world_size = dist.get_world_size() + model, optmizer = prepare_model_optim(shard=True) + reset_model_optim(model, optmizer) + model_state_dict = deepcopy(model.state_dict()) + optimizer_state_dict = deepcopy(optmizer.state_dict()) + reset_model_optim(model, optmizer, 1) + load(dir_name, model, optmizer, get_redist_meta(world_size), get_dist_metas(world_size)) + check_model_state_dict(model_state_dict, model.state_dict()) + check_optim_state_dict(optimizer_state_dict, optmizer.state_dict()) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_global_load_dist(): + model, optimizer = prepare_model_optim() + reset_model_optim(model, optimizer) + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_dist_load_dist(): + with TemporaryDirectory() as dir_name: + # save tp + dp + fn = partial(save_dist, dir_name, False) + launch_dist(fn, 2) + # load tp + dp + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 2) + with TemporaryDirectory() as dir_name: + # save tp + zero + fn = partial(save_dist, dir_name, True) + launch_dist(fn, 4) + # load tp + dp + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 2) + launch_dist(fn, 4) + + +if __name__ == '__main__': + test_save_global_load_global(80 / 1024**3) + test_save_global_load_global(0) + test_save_global_load_dist() + test_save_dist_load_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py new file mode 100644 index 000000000..04e454dcb --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_merge.py @@ -0,0 +1,127 @@ +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import save, merge +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from tempfile import TemporaryDirectory +from torch.optim import Adam +from functools import partial +import torch +import os +import pytest +import colossalai +import torch.nn as nn +import torch.distributed as dist +import torch.multiprocessing as mp + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def test_merge_global(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 0 + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 0 + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={'parallel': { + 'tensor': { + 'mode': '1d', + 'size': 2 + } + }}, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + func() + + +def run_save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + rank = dist.get_rank() + dp_world_size = dist.get_world_size() // 2 + if not zero: + dist_metas = { + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + } + else: + dist_metas = { + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + } + save(dir_name, model, optmizer, dist_meta=dist_metas) + + +@pytest.mark.dist +@pytest.mark.parametrize("zero", [False, True]) +@rerun_if_address_is_in_use() +def test_merge_tp_dp(zero: bool): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name, zero) + world_size = 4 + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 5 + global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 1 + meta = torch.load(os.path.join(output_dir, global_meta['meta'][0])) + assert meta['dist_meta'] is None + assert len(meta['params']) == 2 + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0])) + assert len(model_state_dict) == 2 + assert model_state_dict['fc.weight'].size(1) == 20 + optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict + assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20 + assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20 + + +if __name__ == '__main__': + test_merge_global() + test_merge_tp_dp(False) + test_merge_tp_dp(True) diff --git a/tests/test_utils/test_checkpoint_io/test_merge_param.py b/tests/test_utils/test_checkpoint_io/test_merge_param.py new file mode 100644 index 000000000..5da2ae4fe --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_merge_param.py @@ -0,0 +1,101 @@ +import torch +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param + + +def test_unflatten_zero_param_even() -> None: + dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)] + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).chunk(4)) + unflattened_tensor = unflatten_zero_param(tensors, dist_metas) + assert torch.equal(orig_tensor, unflattened_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_unflatten_zero_param_uneven() -> None: + dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)] + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).split([13, 3])) + unflattened_tensor = unflatten_zero_param(tensors, dist_metas) + assert torch.equal(orig_tensor, unflattened_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_1d_row() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)] + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_1d_col() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)] + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_2d() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)] + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_gather_tp_param_2d_reverse() -> None: + dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)] + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + gathered_tensor = gather_tp_param(tensors, dist_metas) + assert torch.equal(orig_tensor, gathered_tensor) + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_merge_param_hybrid() -> None: + dist_metas = [ + ParamDistMeta(i % 2, + 2, + i // 2, + 6, + tp_shard_dims=[1, 0], + tp_num_parts=[3, 2], + zero_numel=4, + zero_orig_shape=[2, 2]) for i in range(12) + ] + orig_tensor = torch.rand(4, 6) + tensors = [ + chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) + for chunk in t.contiguous().reshape(-1).split([1, 3]) + ] + merged_tensor = merge_param(tensors, dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +def test_merge_param_dummy() -> None: + dist_metas = [ParamDistMeta(0, 1, 0, 1)] + orig_tensor = torch.rand(4, 6) + merged_tensor = merge_param([orig_tensor], dist_metas) + assert torch.equal(orig_tensor, merged_tensor) + + +if __name__ == '__main__': + test_unflatten_zero_param_even() + test_unflatten_zero_param_uneven() + test_gather_tp_param_1d_row() + test_gather_tp_param_1d_col() + test_gather_tp_param_2d() + test_gather_tp_param_2d_reverse() + test_merge_param_hybrid() + test_merge_param_dummy() diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py new file mode 100644 index 000000000..6e76f3167 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_redist.py @@ -0,0 +1,149 @@ +import os +from functools import partial +from tempfile import TemporaryDirectory + +import colossalai +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import redist, save +from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, + RedistMeta) +from torch.optim import Adam + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def get_dist_metas(nprocs: int, zero: bool = False): + dp_world_size = nprocs // 2 + dist_metas = [] + for rank in range(nprocs): + if zero: + dist_metas.append({ + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + }) + else: + dist_metas.append({ + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + }) + return dist_metas + + +def get_redist_meta(nprocs: int): + dp_world_size = nprocs // 2 + rank_meta = { + 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, + 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} + } + param_meta = { + 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamRedistMeta(dp_world_size, 1) + } + return RedistMeta(rank_meta, [], param_meta) + + +def check_checkpoint_shape(dir_name: str): + global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) + for meta_name in global_meta['meta']: + meta = torch.load(os.path.join(dir_name, meta_name)) + assert meta['dist_meta'] is not None + assert len(meta['params']) == 2 + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) + assert len(model_state_dict) == 2 + assert model_state_dict['fc.weight'].size(1) == 10 + optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict + assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10 + assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10 + + +def test_global_to_dist(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + with TemporaryDirectory() as output_dir: + redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) + check_checkpoint_shape(output_dir) + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={'parallel': { + 'tensor': { + 'mode': '1d', + 'size': 2 + } + }}, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + func() + + +def run_save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + rank = dist.get_rank() + save(dir_name, model, optmizer, dist_meta=get_dist_metas(4, zero)[rank]) + + +@pytest.mark.dist +@pytest.mark.parametrize("zero", [False, True]) +@rerun_if_address_is_in_use() +def test_dist_to_dist(zero: bool): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name, zero) + world_size = 4 + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + with TemporaryDirectory() as output_dir: + redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) + if not zero: + assert len(os.listdir(output_dir)) == 0 + else: + check_checkpoint_shape(output_dir) + + +if __name__ == '__main__': + test_global_to_dist() + test_dist_to_dist(False) + test_dist_to_dist(True) diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py new file mode 100644 index 000000000..5ff9d0aa2 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_save.py @@ -0,0 +1,147 @@ +import os +from functools import partial +from tempfile import TemporaryDirectory +from typing import Dict + +import colossalai +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME, + OTHER_CKPT_FILE_NAME) +from colossalai.utils.checkpoint_io.io import save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from torch import Tensor +from torch.optim import Adam + + +def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: + assert set(a.keys()) == set(b.keys()) + for k, v in a.items(): + assert torch.equal(v, b[k]) + + +def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None: + assert set(a['state'].keys()) == set(b['state'].keys()) + for k, state in a['state'].items(): + b_state = b['state'][k] + for v1, v2 in zip(state.values(), b_state.values()): + if isinstance(v1, Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + if not ignore_param_gruops: + assert a['param_groups'] == b['param_groups'] + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(): + model = DummyModel() + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def test_overwrite(): + model = DummyModel() + with TemporaryDirectory() as dir_name: + with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f: + pass + with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'): + save(dir_name, model) + + +def test_save_global(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + assert len(os.listdir(dir_name)) == 5 + global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME + meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) + assert len(meta['model']) == 1 + assert len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) + check_model_state_dict(model.state_dict(), model_state_dict) + optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) + check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict) + other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME)) + assert len(other_state_dict) == 0 + + +def test_save_global_shard(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) + assert len(os.listdir(dir_name)) == 7 + meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) + assert len(meta['model']) == 2 and len(meta['optimizer']) == 2 + model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']] + assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0 + check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]}) + optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']] + assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0 + assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1] + check_optim_state_dict( + optimizer.state_dict(), { + 'state': { + **optimizer_state_dicts[0]['state'], + **optimizer_state_dicts[1]['state'] + }, + 'param_groups': optimizer_state_dicts[0]['param_groups'] + }) + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + func() + + +def run_save_dist(dir_name): + model, optmizer = prepare_model_optim() + dist_metas = { + 'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1), + 'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1) + } + save(dir_name, model, optmizer, dist_meta=dist_metas) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_dist(): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name) + world_size = 2 + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + assert len(os.listdir(dir_name)) == 8 + global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 2 + for rank, meta_name in enumerate(global_meta['meta']): + meta = torch.load(os.path.join(dir_name, meta_name)) + assert meta.get('dist_meta', None) is not None + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) + assert len(model_state_dict) == 2 + optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict + + +if __name__ == '__main__': + test_overwrite() + test_save_global() + test_save_global_shard() + test_save_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py new file mode 100644 index 000000000..8b83caa12 --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py @@ -0,0 +1,137 @@ +import torch +from colossalai.utils.checkpoint_io.meta import ParamRedistMeta +from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param + + +def test_flatten_zero_param_even() -> None: + redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12]) + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).chunk(4)) + flat_tensors = flatten_zero_param(orig_tensor, redist_meta) + assert len(tensors) == len(flat_tensors) + for t, st in zip(tensors, flat_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 1 + unmerged_tensors = unmerged_tensors[0] + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert torch.equal(t, tl) + + +def test_flatten_zero_param_uneven() -> None: + redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13]) + orig_tensor = torch.rand(4, 4) + tensors = list(orig_tensor.reshape(-1).split([13, 3])) + flat_tensors = flatten_zero_param(orig_tensor, redist_meta) + assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0 + flat_tensors = flat_tensors[1:-1] + assert len(tensors) == len(flat_tensors) + for t, st in zip(tensors, flat_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 1 + unmerged_tensors = unmerged_tensors[0] + assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0 + unmerged_tensors = unmerged_tensors[1:-1] + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert torch.equal(t, tl) + + +def test_split_tp_param_1d_row() -> None: + redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4]) + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_split_tp_param_1d_col() -> None: + redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4]) + orig_tensor = torch.rand(4, 4) + tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_split_tp_param_2d() -> None: + redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_split_tp_param_2d_reverse() -> None: + redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) + orig_tensor = torch.rand(4, 6) + tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] + split_tensors = split_tp_param(orig_tensor, redist_meta) + assert len(tensors) == len(split_tensors) + for t, st in zip(tensors, split_tensors): + assert torch.equal(t, st) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(tensors) == len(unmerged_tensors) + for t, tl in zip(tensors, unmerged_tensors): + assert len(tl) == 1 + assert torch.equal(t, tl[0]) + + +def test_unmerge_param_hybrid() -> None: + redist_meta = ParamRedistMeta(2, + 6, + tp_shard_dims=[1, 0], + tp_num_parts=[3, 2], + zero_start_dp_rank=0, + zero_offsets=[0, 1]) + orig_tensor = torch.rand(4, 6) + tensors = [ + chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) + for chunk in t.contiguous().reshape(-1).split([1, 3]) + ] + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2 + for tp_rank in range(6): + for dp_rank in range(2): + assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank]) + + +def test_unmerge_param_dummy() -> None: + redist_meta = ParamRedistMeta(1, 1) + orig_tensor = torch.rand(4, 6) + unmerged_tensors = unmerge_param(orig_tensor, redist_meta) + assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1 + assert torch.equal(orig_tensor, unmerged_tensors[0][0]) + + +if __name__ == '__main__': + test_flatten_zero_param_even() + test_flatten_zero_param_uneven() + test_split_tp_param_1d_row() + test_split_tp_param_1d_col() + test_split_tp_param_2d() + test_split_tp_param_2d_reverse() + test_unmerge_param_hybrid() + test_unmerge_param_dummy() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py new file mode 100644 index 000000000..58e3b21d9 --- /dev/null +++ b/tests/test_utils/test_flash_attention.py @@ -0,0 +1,146 @@ +import pytest +import torch +from einops import rearrange + +from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON + +if HAS_FLASH_ATTN: + from colossalai.kernel.cuda_native.flash_attention import ( + MaskedFlashAttention, + flash_attention_q_k_v, + flash_attention_q_kv, + flash_attention_qkv, + ) + +if HAS_TRITON: + from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention + +if HAS_MEM_EFF_ATTN: + from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention + + +def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + return ref_out + + +@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available") +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)]) +def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + sm_scale = 0.3 + dout = torch.randn_like(q) + + ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # triton implementation + tri_out = triton_flash_attention(q, k, v, sm_scale) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-3) + assert torch.allclose(ref_dv, tri_dv, atol=1e-3) + assert torch.allclose(ref_dk, tri_dk, atol=1e-3) + assert torch.allclose(ref_dq, tri_dq, atol=1e-3) + + +@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available") +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)]) +def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + sm_scale = 0.3 + dout = torch.randn_like(q) + + # reference implementation + ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # flash implementation + q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v]) + dout = rearrange(dout, 'z h n d -> (z n) h d').detach() + for i in range(3): + if i == 0: + tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True) + elif i == 1: + kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1) + tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True) + else: + qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1) + tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True) + + tri_out.backward(dout, retain_graph=True) + + if i == 0: + tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout) + tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), + (tri_out, tri_dq, tri_dk, tri_dv)) + elif i == 1: + tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout) + tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1) + tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), + (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1))) + else: + tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout) + tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1) + tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), + (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1))) + + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-3) + assert torch.allclose(ref_dv, tri_dv, atol=1e-3) + assert torch.allclose(ref_dk, tri_dk, atol=1e-3) + assert torch.allclose(ref_dq, tri_dq, atol=1e-3) + + +@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available") +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)]) +def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1) + + qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + attention_mask = torch.randint(2, (Z, H)).cuda().bool() + + out = attn(qkv, attention_mask) + + dout = torch.rand_like(out) + out.backward(dout) + + +@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)]) +def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1) + + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + + out = attn(q, k, v, attention_mask=LowerTriangularMask()) + + dout = torch.rand_like(out) + out.backward(dout) + + +if __name__ == '__main__': + test_flash_attention(3, 4, 2, 16) diff --git a/tests/test_zero/low_level_zero/test_grad_acc.py b/tests/test_zero/low_level_zero/test_grad_acc.py new file mode 100644 index 000000000..c23b3a3e8 --- /dev/null +++ b/tests/test_zero/low_level_zero/test_grad_acc.py @@ -0,0 +1,167 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.testing.random import seed_all +from colossalai.utils import free_port +from colossalai.zero import LowLevelZeroOptimizer + + +class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def exam_zero_1_2_grad_acc(): + local_rank = torch.distributed.get_rank() + seed_all(2009) + + # create model + zero1_model = TestModel().cuda() + zero2_model = copy.deepcopy(zero1_model) + + # create optimizer + zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) + zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) + zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, + overlap_communication=True, + initial_scale=32, + clip_grad_norm=1.0, + verbose=True) + zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=32, + clip_grad_norm=1.0) + # create data + seed_all(2021 + local_rank) + input_data1 = torch.randn(32, 128).cuda() + input_data2 = torch.randn(32, 128).cuda() + + def fwd_bwd_func(number, cur_data): + # zero-dp forward + zero1_output = zero1_model(cur_data) + zero2_output = zero2_model(cur_data) + assert torch.equal(zero1_output, zero2_output) + + # zero-dp backward + zero1_optimizer.backward(zero1_output.sum().float()) + zero2_optimizer.backward(zero2_output.sum().float()) + + for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): + if z2p.grad is not None: + # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) + assert torch.equal(z1p.grad, z2p.grad) + + zero1_optimizer.sync_grad() + zero2_optimizer.sync_grad() + + fwd_bwd_func(0, input_data1) + fwd_bwd_func(1, input_data2) + + # step + zero1_optimizer.step() + zero2_optimizer.step() + + # check updated param + for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): + assert torch.equal(z1p.data, z2p.data) + + +def exam_zero_1_grad_acc(): + local_rank = torch.distributed.get_rank() + grad_scale = 32 + seed_all(2008) + + # create models + zero_model = TestModel() + torch_model = copy.deepcopy(zero_model) + + zero_model = zero_model.cuda() + torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) + + # create optimizer + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) + + # we only test stage 1 here + # in `check_sharded_param_consistency.py`, we will test whether + # level 1 and 2 will produce exactly the same results + zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, + overlap_communication=False, + initial_scale=grad_scale, + reduce_bucket_size=262144, + clip_grad_norm=1.0) + + torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) + + # create data + seed_all(2022 + local_rank) + input_data1 = torch.randn(32, 128).cuda() + input_data2 = torch.randn(32, 128).cuda() + + def fwd_bwd_func(number, cur_data, check_flag): + # zero-dp forward + zero_output = zero_model(cur_data) + + # torch-ddp forward + torch_output = torch_model(cur_data) + assert torch.equal(zero_output, torch_output) + + # zero-dp backward + zero_optimizer.backward(zero_output.sum().float()) + # torch-ddp backward + torch_output.sum().backward() + + if check_flag: + # check grad + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + unscale_grad = z1p.grad / grad_scale + # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) + assert torch.equal(p.grad, unscale_grad) + + zero_optimizer.sync_grad() + + fwd_bwd_func(0, input_data1, True) + fwd_bwd_func(1, input_data2, False) + + zero_optimizer.step() + torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) + torch_optimizer.step() + + # check updated param + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data))) + assert_close(p.data, z1p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + exam_zero_1_grad_acc() + # exam_zero_1_2_grad_acc() + + +@pytest.mark.dist +def test_grad_accumulation(): + world_size = 2 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_grad_accumulation() diff --git a/tests/test_zero/low_level_zero/test_zero1_2.py b/tests/test_zero/low_level_zero/test_zero1_2.py new file mode 100644 index 000000000..b02d3a6a4 --- /dev/null +++ b/tests/test_zero/low_level_zero/test_zero1_2.py @@ -0,0 +1,186 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.testing.random import seed_all +from colossalai.utils import free_port +from colossalai.zero import LowLevelZeroOptimizer + + +class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def half_close(a, b, loose=False): + rtol = None + atol = None + if loose: + rtol = 5e-2 + atol = 5e-4 + + a = a.detach().half() + b = b.detach().half() + + assert_close(a, b, rtol=rtol, atol=atol) + + +def exam_zero_1_2(): + """ + In this test, we want to test whether zero stage 1 and 2 + deliver the same numerical results despite different communication + pattern + + we use these prefixes to differentiate the zero stage + oss: partition optimizer states + pg: partition gradients and optimizer states + + """ + local_rank = torch.distributed.get_rank() + seed_all(2001) + + # create model + zero1_model = TestModel().cuda() + zero2_model = copy.deepcopy(zero1_model) + + # create optimizer + zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) + zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) + zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, + overlap_communication=True, + initial_scale=128, + verbose=True) + zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=128) + # create data + seed_all(2001 + local_rank) + input_data = torch.randn(32, 128).cuda() + + zero1_output = zero1_model(input_data) + zero2_output = zero2_model(input_data) + assert torch.equal(zero1_output, zero2_output) + + # zero-dp backward + zero1_optimizer.backward(zero1_output.mean().float()) + zero2_optimizer.backward(zero2_output.mean().float()) + + for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): + if z2p.grad is not None: + # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) + assert torch.equal(z1p.grad, z2p.grad) + + zero1_optimizer.sync_grad() + zero2_optimizer.sync_grad() + + # step + zero1_optimizer.step() + zero2_optimizer.step() + + # check updated param + for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): + assert torch.equal(z1p.data, z2p.data) + + +def exam_zero_1_torch_ddp(): + """ + In this test, two pairs of model and optimizers are created. + 1. zero: use sharded optimizer and fp16 parameters + 2. torch: use torch DDP and fp32 parameters + + We feed these two sets of models with the same input and check if the + differences in model output and updated parameters are within tolerance. + """ + local_rank = torch.distributed.get_rank() + seed_all(1453) + + # create models + zero_model = TestModel() + torch_model = copy.deepcopy(zero_model) + + zero_model = zero_model.cuda().half() + # torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) + torch_model = torch_model.cuda() + + # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # half_close(p.data, z1p.data) + + # create optimizer + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + + # we only test stage 1 here + # in `check_sharded_param_consistency.py`, we will test whether + # level 1 and 2 will produce exactly the same results + zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, + overlap_communication=True, + initial_scale=1, + reduce_bucket_size=262144) + + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + + seed_all(1453 + local_rank) + # create + input_data = torch.rand(32, 128).cuda() + + # zero-dp forward + zero_output = zero_model(input_data.half()) + + # torch-ddp forward + torch_output = torch_model(input_data) + half_close(zero_output, torch_output, loose=True) + + # zero-dp backward + zero_optimizer.backward(zero_output.mean().float()) + + # torch-ddp backward + torch_output.mean().backward() + + # check grad + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + half_close(p.grad, z1p.grad, loose=True) + + # zero-dp step + zero_optimizer.sync_grad() + zero_optimizer.step() + + # torch ddp step + torch_optimizer.step() + + # check updated param + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # print(n, torch.max(torch.abs(p.data - z1p.data))) + half_close(p.data, z1p.data, loose=True) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + exam_zero_1_torch_ddp() + exam_zero_1_2() + + +@pytest.mark.dist +def test_zero_1_2(): + world_size = 2 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_1_2() diff --git a/tests/test_zero/test_init_context.py b/tests/test_zero/test_init_context.py index b955e4852..0cba7a492 100644 --- a/tests/test_zero/test_init_context.py +++ b/tests/test_zero/test_init_context.py @@ -3,30 +3,34 @@ from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp +from common import CONFIG + +import colossalai +from colossalai.gemini.memory_tracer.utils import colo_model_mem_usage from colossalai.logging import get_dist_logger from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.gemini.memory_tracer.model_data_memtracer import \ - colo_model_mem_usage from colossalai.utils.memory import colo_device_memory_used from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from tests.components_to_test.registry import non_distributed_component_funcs -from common import CONFIG - @parameterize("init_device_type", ['cpu', 'cuda']) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_model_test(init_device_type, shard_strategy_class): logger = get_dist_logger("test_zero_init") - for get_components_func in non_distributed_component_funcs: + for name, get_components_func in non_distributed_component_funcs._registry.items(): + # because the ZeroInitContext automatically turns parameters to fp16 + # and the beit model use tensor.erfinv_() function to initialize weights + # tensor.erfinv_() doesn't support Half in CPU, we omit the beit model + if name == 'beit': + continue model_builder, _, _, _, _ = get_components_func() if init_device_type == 'cuda': init_device = get_current_device() @@ -71,4 +75,4 @@ def test_zero_init_context(world_size): if __name__ == '__main__': - test_zero_init_context(4) + test_zero_init_context(1) diff --git a/tests/test_zero/test_mem_collector.py b/tests/test_zero/test_mem_collector.py deleted file mode 100644 index bea971935..000000000 --- a/tests/test_zero/test_mem_collector.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import colossalai -import pytest -import torch.multiprocessing as mp -import torch.nn as nn -import torch.nn.functional as F -from colossalai.utils.cuda import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.shard_utils import BucketTensorShardStrategy -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from functools import partial - - -class MyTestModel(torch.nn.Module): - - def __init__(self) -> None: - super().__init__() - self.proj1 = nn.Linear(512, 512) - self.weight = nn.Parameter(torch.randn(1024, 512)) - self.proj2 = nn.Linear(1024, 512) - - def forward(self, x): - x = self.proj1(x) - x = F.linear(x, self.weight) - x = self.proj2(x) - - return x - - -def run_mem_collector_testing(): - cuda_capacity = colo_device_memory_capacity(get_current_device()) - fraction = (50 * 1024**2) / cuda_capacity - # limit max memory to 50MB - colo_set_process_memory_fraction(fraction) - shard_strategy = BucketTensorShardStrategy() - with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True): - model = MyTestModel() - - model = ShardedModelV2(module=model, - shard_strategy=shard_strategy, - reduce_scatter_bucket_size_mb=1, - tensor_placement_policy='auto') - - data = torch.randn(2, 512, device=get_current_device()) - - output = model(data) - loss = torch.mean(output) - model.backward(loss) - - cuda_model_data_list = model._memstats_collector.model_data_list('cuda') - assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032] - - cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda') - assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1] - assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1] - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_mem_collector_testing() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_mem_collector(world_size=2): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_mem_collector() diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py index 654c82a46..95a9dee38 100644 --- a/tests/test_zero/test_shard_model_v2.py +++ b/tests/test_zero/test_shard_model_v2.py @@ -3,27 +3,27 @@ from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp +from common import CONFIG, check_grads_padding, run_fwd_bwd +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) +from colossalai.zero.shard_utils import BucketTensorShardStrategy from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 from colossalai.zero.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP - -from common import CONFIG, check_grads_padding, run_fwd_bwd @parameterize("enable_autocast", [True]) @parameterize("shard_strategy_class", [BucketTensorShardStrategy]) def run_model_test(enable_autocast, shard_strategy_class): - test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module'] + test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] shard_strategy = shard_strategy_class() for model_name in test_models: get_components_func = non_distributed_component_funcs.get_callable(model_name) diff --git a/tests/test_zero/test_sharded_optim_v2.py b/tests/test_zero/test_sharded_optim_v2.py index 2b42a7128..8fe7eb639 100644 --- a/tests/test_zero/test_sharded_optim_v2.py +++ b/tests/test_zero/test_sharded_optim_v2.py @@ -1,25 +1,25 @@ from functools import partial -import colossalai -from colossalai.utils.cuda import get_current_device import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp +from common import CONFIG, check_sharded_model_params +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import CPUAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model.utils import col_model_deepcopy from colossalai.zero.sharded_optim import ShardedOptimizerV2 from colossalai.zero.sharded_optim._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP - -from common import CONFIG, check_sharded_model_params def _run_step(model, optimizer, data, label, criterion, enable_autocast=False): @@ -45,7 +45,7 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False): @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) @parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module'] + test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] shard_strategy = shard_strategy_class() if use_cpuadam and cpu_offload is False: @@ -64,7 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g zero_model = ShardedModelV2( zero_model, shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', + tensor_placement_policy='cpu' if cpu_offload else 'auto', reuse_fp16_shard=use_cpuadam, ) diff --git a/version.txt b/version.txt index f115116c1..0ea3a944b 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.1.11rc1 +0.2.0