From bd1ab9815813986d3c2d8c59ca739afc77e72979 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 9 Jun 2023 09:48:49 +0800 Subject: [PATCH 1/9] [gemini] fixed the gemini checkpoint io (#3934) --- colossalai/booster/plugin/gemini_plugin.py | 7 +++++-- colossalai/checkpoint_io/index_file.py | 18 ++++++++++-------- colossalai/zero/gemini/gemini_ddp.py | 5 ++++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 46714fe1c..4a7efc165 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO): save_state_dict(shard, checkpoint_file_path, use_safetensors) index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - logging.info(f"The model is going to be split to checkpoint shards. " + + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.write_index_file(save_index_file) + logging.info(f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}.") diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 334ecbc04..a41cc482e 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -1,8 +1,8 @@ import json -from pathlib import Path -from typing import Any, List, Union import os -import json +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, List, Union from .utils import is_dtensor_checkpoint @@ -22,8 +22,10 @@ class CheckpointIndexFile: def __init__(self, root_path=None) -> None: self.root_path = root_path - self.metadata: dict = dict() - self.weight_map: dict = dict() + + # use ordered dict to preserve the tensor checkpoint order + self.metadata: Dict = OrderedDict() + self.weight_map: Dict = OrderedDict() @staticmethod def from_file(index_path: Union[str, Path]): @@ -150,13 +152,13 @@ class CheckpointIndexFile: """ ckpt_path = self.weight_map[param_name] return ckpt_path - + def get_all_param_names(self): """ Get all the weight keys. """ return list(self.weight_map.keys()) - + def write_index_file(self, save_index_file): """ Write index file. 
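For reference, a minimal sketch of the index file this class now produces (insertion order preserved by the OrderedDict change above and by dropping sort_keys in the hunk below); the parameter names, shard filenames, directory and total_size value are illustrative, only the method calls and the JSON layout follow the code in this diff:

    import os
    from colossalai.checkpoint_io import CheckpointIndexFile

    os.makedirs("/tmp/ckpt", exist_ok=True)    # hypothetical checkpoint directory
    index = CheckpointIndexFile("/tmp/ckpt")
    index.append_weight_map("conv1.weight", "pytorch_model-00001.bin")
    index.append_weight_map("fc.bias", "pytorch_model-00002.bin")
    index.append_meta_data("total_size", 1024)
    index.write_index_file("pytorch_model.bin.index.json")
    # resulting JSON, with weights listed in the order they were appended:
    # {"metadata": {"total_size": 1024},
    #  "weight_map": {"conv1.weight": "pytorch_model-00001.bin",
    #                 "fc.bias": "pytorch_model-00002.bin"}}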
@@ -164,5 +166,5 @@ class CheckpointIndexFile: save_index_file = os.path.join(self.root_path, save_index_file) index = {"metadata": self.metadata, "weight_map": self.weight_map} with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" + content = json.dumps(index, indent=2) + "\n" f.write(content) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 7e23fdb42..094320c4a 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -716,7 +716,10 @@ class _StateDictSharder: tensor_size = calculate_tensor_size(tensor) ret_block = None ret_block_size = 0 - if self.current_block_size + tensor_size > self.max_shard_size: + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: ret_block = self.current_block ret_block_size = self.current_block_size self.current_block = OrderedDict() From 4110d1f0d4baaf76be03fd449a9e724c48ff6eeb Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 12 Jun 2023 09:50:57 +0800 Subject: [PATCH 2/9] [workflow] cancel duplicated workflow jobs (#3960) --- .github/workflows/build_on_pr.yml | 9 +++++++++ .github/workflows/compatiblity_test_on_pr.yml | 6 ++++++ .github/workflows/doc_check_on_pr.yml | 6 ++++++ .github/workflows/doc_test_on_pr.yml | 6 ++++++ .github/workflows/example_check_on_pr.yml | 6 ++++++ 5 files changed, 33 insertions(+) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index a2807859b..fdcfd21d6 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -60,6 +60,9 @@ jobs: defaults: run: shell: bash + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - name: Copy testmon cache run: | # branch name may contain slash, we need to replace it with space @@ -83,6 +86,9 @@ jobs: changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }} anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }} runs-on: ubuntu-latest + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v2 with: @@ -140,6 +146,9 @@ jobs: defaults: run: shell: bash + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - name: Checkout TensorNVMe uses: actions/checkout@v2 diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index 94a723388..5098b8e36 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -12,6 +12,9 @@ jobs: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v3 - id: set-matrix @@ -40,6 +43,9 @@ jobs: image: ${{ matrix.container }} options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 timeout-minutes: 120 + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - name: Install dependencies run: | diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index 992cc93b0..848991bd3 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -16,6 +16,9 @@ jobs: github.event.pull_request.draft == false && 
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: ubuntu-latest + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v2 @@ -31,6 +34,9 @@ jobs: github.event.pull_request.draft == false && github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: ubuntu-latest + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v2 with: diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml index 325e2a7c9..2a07a2297 100644 --- a/.github/workflows/doc_test_on_pr.yml +++ b/.github/workflows/doc_test_on_pr.yml @@ -19,6 +19,9 @@ jobs: outputs: any_changed: ${{ steps.changed-files.outputs.any_changed }} changed_files: ${{ steps.changed-files.outputs.all_changed_files }} + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false name: Detect changed example files steps: - uses: actions/checkout@v3 @@ -59,6 +62,9 @@ jobs: defaults: run: shell: bash + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - name: Checkout ColossalAI-Documentation uses: actions/checkout@v2 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 31dbf7540..ee456c25f 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -20,6 +20,9 @@ jobs: matrix: ${{ steps.setup-matrix.outputs.matrix }} anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }} name: Detect changed example files + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v3 with: @@ -77,6 +80,9 @@ jobs: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ timeout-minutes: 10 + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v3 From 71fe52769cde5e8e8bd4c2703593d45193f35f3f Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 9 Jun 2023 09:48:49 +0800 Subject: [PATCH 3/9] [gemini] fixed the gemini checkpoint io (#3934) --- colossalai/booster/plugin/gemini_plugin.py | 7 +++++-- colossalai/checkpoint_io/index_file.py | 18 ++++++++++-------- colossalai/zero/gemini/gemini_ddp.py | 5 ++++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 46714fe1c..4a7efc165 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO): save_state_dict(shard, checkpoint_file_path, use_safetensors) index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - logging.info(f"The model is going to be split to checkpoint shards. " + + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.write_index_file(save_index_file) + logging.info(f"The model is split into checkpoint shards. 
" f"You can find where each parameters has been saved in the " f"index located at {save_index_file}.") diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 334ecbc04..a41cc482e 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -1,8 +1,8 @@ import json -from pathlib import Path -from typing import Any, List, Union import os -import json +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, List, Union from .utils import is_dtensor_checkpoint @@ -22,8 +22,10 @@ class CheckpointIndexFile: def __init__(self, root_path=None) -> None: self.root_path = root_path - self.metadata: dict = dict() - self.weight_map: dict = dict() + + # use ordered dict to preserve the tensor checkpoint order + self.metadata: Dict = OrderedDict() + self.weight_map: Dict = OrderedDict() @staticmethod def from_file(index_path: Union[str, Path]): @@ -150,13 +152,13 @@ class CheckpointIndexFile: """ ckpt_path = self.weight_map[param_name] return ckpt_path - + def get_all_param_names(self): """ Get all the weight keys. """ return list(self.weight_map.keys()) - + def write_index_file(self, save_index_file): """ Write index file. @@ -164,5 +166,5 @@ class CheckpointIndexFile: save_index_file = os.path.join(self.root_path, save_index_file) index = {"metadata": self.metadata, "weight_map": self.weight_map} with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" + content = json.dumps(index, indent=2) + "\n" f.write(content) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 7e23fdb42..094320c4a 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -716,7 +716,10 @@ class _StateDictSharder: tensor_size = calculate_tensor_size(tensor) ret_block = None ret_block_size = 0 - if self.current_block_size + tensor_size > self.max_shard_size: + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: ret_block = self.current_block ret_block_size = self.current_block_size self.current_block = OrderedDict() From 6718a2f2857ad9cc7210f867288c9f56ec3a9045 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 12 Jun 2023 09:50:57 +0800 Subject: [PATCH 4/9] [workflow] cancel duplicated workflow jobs (#3960) --- .github/workflows/build_on_pr.yml | 9 +++++++++ .github/workflows/compatiblity_test_on_pr.yml | 6 ++++++ .github/workflows/doc_check_on_pr.yml | 6 ++++++ .github/workflows/doc_test_on_pr.yml | 6 ++++++ .github/workflows/example_check_on_pr.yml | 6 ++++++ 5 files changed, 33 insertions(+) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 8b2253e57..513de40b7 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -60,6 +60,9 @@ jobs: defaults: run: shell: bash + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - name: Copy testmon cache run: | # branch name may contain slash, we need to replace it with space @@ -83,6 +86,9 @@ jobs: changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }} anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }} runs-on: ubuntu-latest + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: 
actions/checkout@v2 with: @@ -140,6 +146,9 @@ jobs: defaults: run: shell: bash + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - name: Checkout TensorNVMe uses: actions/checkout@v2 diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index 94a723388..5098b8e36 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -12,6 +12,9 @@ jobs: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v3 - id: set-matrix @@ -40,6 +43,9 @@ jobs: image: ${{ matrix.container }} options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 timeout-minutes: 120 + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - name: Install dependencies run: | diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index 992cc93b0..848991bd3 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -16,6 +16,9 @@ jobs: github.event.pull_request.draft == false && github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: ubuntu-latest + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v2 @@ -31,6 +34,9 @@ jobs: github.event.pull_request.draft == false && github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: ubuntu-latest + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v2 with: diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml index 325e2a7c9..2a07a2297 100644 --- a/.github/workflows/doc_test_on_pr.yml +++ b/.github/workflows/doc_test_on_pr.yml @@ -19,6 +19,9 @@ jobs: outputs: any_changed: ${{ steps.changed-files.outputs.any_changed }} changed_files: ${{ steps.changed-files.outputs.all_changed_files }} + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false name: Detect changed example files steps: - uses: actions/checkout@v3 @@ -59,6 +62,9 @@ jobs: defaults: run: shell: bash + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - name: Checkout ColossalAI-Documentation uses: actions/checkout@v2 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 31dbf7540..ee456c25f 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -20,6 +20,9 @@ jobs: matrix: ${{ steps.setup-matrix.outputs.matrix }} anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }} name: Detect changed example files + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v3 with: @@ -77,6 +80,9 @@ jobs: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ timeout-minutes: 10 + concurrency: + group: ${{ github.head_ref }} + cancel-in-progress: false steps: - uses: actions/checkout@v3 From 8bcad7367769633699c4ec5b6d94f2119ff44a68 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 13 Jun 2023 14:42:35 +0800 Subject: [PATCH 5/9] [workflow] fixed the directory check in build (#3980) --- .github/workflows/build_on_pr.yml | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git 
a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 513de40b7..ac186a585 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -40,8 +40,8 @@ jobs: - name: Copy testmon cache run: | # branch name may contain slash, we need to replace it with space export REF_BRANCH=$(echo ${{ github.event.ref }} | sed "s/\// /") - if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ]; then - [ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ] && cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}" + if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ] && [ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ]; then + cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}" fi env: MAIN_BRANCH: ${{ github.event.master_branch }} @@ -67,8 +67,8 @@ jobs: - name: Copy testmon cache run: | # branch name may contain slash, we need to replace it with space export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /") - if [ -d "/github/home/testmon_cache/${BASE}" ]; then - [ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ] && mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER} + if [ -d "/github/home/testmon_cache/${BASE}" ] and [ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ]; then + mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER} fi env: PR_NUMBER: ${{ github.event.number }} @@ -159,7 +159,9 @@ jobs: - name: Restore TensorNVMe Cache run: | - [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ] && cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe + if [ -d /github/home/tensornvme_cache ] && [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ]; then + cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe + fi - name: Install TensorNVMe run: | @@ -182,7 +184,9 @@ jobs: if: needs.detect.outputs.anyExtensionFileChanged != 'true' run: | # -p flag is required to preserve the file timestamp to avoid ninja rebuild - [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + if [ -d /github/home/cuda_ext_cache ] && [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ]; then + cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + fi - name: Install Colossal-AI run: | @@ -273,8 +277,8 @@ jobs: if: github.event.pull_request.merged == true run: | # branch name may contain slash, we need to replace it with space export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /") - if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ]; then - [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ] && cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/" + if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ] && [ ! 
-z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ]; then + cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/" fi env: PR_NUMBER: ${{ github.event.pull_request.number }} From c9cff7e7fa34c6f1640141b3f3af2d08c1ec7534 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 15 Jun 2023 15:21:26 +0800 Subject: [PATCH 6/9] [checkpointio] General Checkpointing of Sharded Optimizers (#3984) --- colossalai/booster/plugin/gemini_plugin.py | 6 +- colossalai/booster/plugin/torch_ddp_plugin.py | 16 +- .../booster/plugin/torch_fsdp_plugin.py | 9 +- .../checkpoint_io/checkpoint_io_base.py | 12 +- .../checkpoint_io/general_checkpoint_io.py | 95 +++++++-- colossalai/checkpoint_io/index_file.py | 14 +- colossalai/checkpoint_io/utils.py | 185 +++++++++++++++++- .../test_general_checkpoint_io.py | 100 +++++++++- 8 files changed, 399 insertions(+), 38 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 4a7efc165..ce01ad111 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO -from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict +from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device @@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO): model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, - variant: Optional[str] = None, + prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): """ Save sharded model """ state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) - weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) total_size = 0 index_file = CheckpointIndexFile(checkpoint_path) for idx, shard_pair in enumerate(state_dict_shard): diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index b317ccf48..a18073db6 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): """ Save model to checkpoint but only on master process. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap if self.coordinator.is_master(): super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) @@ -54,11 +53,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False, - variant: Optional[str] = None, + prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): + """ + Save model to checkpoint but only on master process. 
+ """ if self.coordinator.is_master(): - super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors) + super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors) + + def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, + size_per_shard: int): + """ + Save optimizer to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) class TorchDDPModel(ModelWrapper): diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 8d534ea4c..ebd03b6ea 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,9 +1,9 @@ +import warnings from pathlib import Path from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union import torch import torch.nn as nn -import warnings from packaging import version from torch.distributed import ProcessGroup @@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str], + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], size_per_shard: int, use_safetensors: bool): """ Save model to checkpoint but only on master process. @@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): """ raise NotImplementedError("Sharded model checkpoint is not supported yet.") - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, + size_per_shard: int): """ Save optimizer to checkpoint but only on master process. """ raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int): """ Load optimizer to checkpoint but only on master process. """ diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index fbc8fc542..9d513043f 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -103,7 +103,7 @@ class CheckpointIO(ABC): checkpoint: str, shard: bool = False, gather_dtensor: bool = True, - variant: str = None, + prefix: str = None, size_per_shard: int = 1024, use_safetensors: bool = False): """ @@ -128,7 +128,7 @@ class CheckpointIO(ABC): multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure that the checkpoint path is a directory path instead of a file path. gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. - variant (str): If specified, weights are saved in the format pytorch_model..bin. Default: None. + prefix (str): If specified, weights are saved in the format pytorch_model..bin. Default: None. 
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved """ @@ -137,11 +137,11 @@ class CheckpointIO(ABC): model = model.unwrap() if shard: - self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors) + self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) else: self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) - def load_optimizer(self, optimizer: Optimizer, checkpoint: str): + def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024): """ Load optimizer from checkpoint. @@ -157,7 +157,7 @@ class CheckpointIO(ABC): if index_file_exists: # the existence of index file means it is a sharded checkpoint - self.load_sharded_optimizer(optimizer, index_file_path) + self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard) else: self.load_unsharded_optimizer(optimizer, checkpoint) @@ -218,7 +218,7 @@ class CheckpointIO(ABC): pass @abstractmethod - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str], + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], size_per_shard: int, use_safetensors: bool): """ Save model to sharded checkpoint. diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 2cc9c3faa..d8e133313 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -11,15 +11,21 @@ from torch.optim import Optimizer from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( - get_base_filenames, + get_model_base_filenames, + get_optimizer_base_filenames, get_shard_filename, has_index_file, is_safetensors_available, + load_param_groups_into_optimizer, load_shard_state_dict, load_state_dict, load_state_dict_into_model, + load_states_into_optimizer, + save_param_groups, save_state_dict, - shard_checkpoint, + shard_model_checkpoint, + shard_optimizer_checkpoint, + sharded_optimizer_loading_epilogue, ) __all__ = ['GeneralCheckpointIO'] @@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO): # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) - def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + """ + Load sharded optimizer with the given path to index file. + """ + optimizer.load_state_dict + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): - checkpoint = load_state_dict(checkpoint) - optimizer.load_state_dict(checkpoint) + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. 
\ + Lacking param group file under current directory.') + id_map = load_param_groups_into_optimizer(optimizer, param_group_path) + + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + for shard_file in checkpoint_files: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + load_states_into_optimizer(optimizer, state_dict, id_map) + del state_dict + gc.collect() + + sharded_optimizer_loading_epilogue(optimizer) def save_sharded_optimizer( self, @@ -59,7 +83,54 @@ class GeneralCheckpointIO(CheckpointIO): prefix: str, size_per_shard: int, ): - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Offload optimizer states. States are broken into shards within max_shard_size. + state_dict = optimizer.state_dict() + sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard) + + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + # Store the information of param groups to param_group_file. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(state_dict, group_file_path) + + # Save shards of optimizer states. + total_size = 0 + for idx, shard_pair in enumerate(sharded_state): + shard, current_size = shard_pair + shard_file = get_shard_filename(states_name, idx) + total_size = total_size + current_size + for param_id in shard.keys(): + index_file.append_weight_map(str(param_id), shard_file) + + checkpoint_file_path = os.path.join(checkpoint, shard_file) + save_state_dict(shard, checkpoint_file_path, use_safetensors=False) + + # Wrap up index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info(f"The optimizer is going to be split to checkpoint shards. 
" + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + checkpoint = load_state_dict(checkpoint) + optimizer.load_state_dict(checkpoint) def save_unsharded_optimizer( self, @@ -74,7 +145,7 @@ class GeneralCheckpointIO(CheckpointIO): model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False, - variant: Optional[str] = None, + prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): """ @@ -89,9 +160,9 @@ class GeneralCheckpointIO(CheckpointIO): # shard checkpoint state_dict = model.state_dict() - state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size) + state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size) - weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) total_size = 0 index_file = CheckpointIndexFile(checkpoint_path) for idx, shard_pair in enumerate(state_dict_shard): @@ -128,7 +199,7 @@ class GeneralCheckpointIO(CheckpointIO): # read checkpoint index file ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) - checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() missing_keys = [] for shard_file in checkpoint_files: diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index a41cc482e..388cf3fbe 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -111,7 +111,7 @@ class CheckpointIndexFile: return True return False - def get_checkpoint_fileanames(self) -> List[str]: + def get_checkpoint_filenames(self) -> List[str]: """ Get the set of checkpoint filenames in the weight map. @@ -159,6 +159,18 @@ class CheckpointIndexFile: """ return list(self.weight_map.keys()) + def get_param_group_filename(self) -> Union[str, None]: + """ + Get the file name of param_group file if this is a checkpoint for optimizer. + Returns: + str: param_group file name + """ + filename = self.metadata.get("param_groups", None) + if filename: + return str(self.root_path.joinpath(filename)) + else: + return None + def write_index_file(self, save_index_file): """ Write index file. 
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 435feda4a..21b70343b 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,17 +1,24 @@ # coding=utf-8 import re +from collections import abc as container_abcs +from collections import defaultdict +from itertools import chain from pathlib import Path from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple import torch import torch.nn as nn +from torch.optim import Optimizer from colossalai.tensor.d_tensor.d_tensor import DTensor SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" +STATES_NAME = "pytorch_optim.bin" SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" +STATES_INDEX_NAME = "pytorch_optim.bin.index.json" +GROUP_FILE_NAME = "pytorch_optim_group.bin" # ====================================== # General helper functions @@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # ====================================== # Helper functions for saving shard file # ====================================== -def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: +def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. @@ -110,6 +117,50 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It yield current_block, current_block_size +def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: + """ + Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + """ + + # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. + states = state_dict['state'] + + current_block = {} + current_block_size = 0 + + for param_id, state in states.items(): + + ret_block = None + ret_block_size = 0 + + # A state might contain more than one tensors. + # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' + state_size = 0 + isDTensor = False + for state_tensor in state.values(): + # If the states are stored as DTensors, mark isDTensor as true. + if type(state_tensor) == DTensor: + isDTensor = True + state_size += calculate_tensor_size(state_tensor) + + if not isDTensor: + + if current_block_size + state_size > max_shard_size: + ret_block = current_block + ret_block_size = current_block_size + current_block = {} + current_block_size = 0 + + current_block[param_id] = state + current_block_size += state_size + + if ret_block != None: + yield ret_block, ret_block_size + + yield current_block, current_block_size + + def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): """ load shard state dict into model @@ -179,6 +230,96 @@ def load_state_dict_into_model(model: nn.Module, model.__class__.__name__, "\n\t".join(error_msgs))) +def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict: + """ + Load information of param_groups into an initialized optimizer. + """ + + # Load list of param_groups from given file path. + # The params in saved_groups are in the form of integer indices. 
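    # Illustration (hypothetical ids and tensors): a saved group such as
    #     {'lr': 0.001, 'params': [0, 1, 2]}
    # is matched positionally against the live group
    #     {'lr': 0.001, 'params': [w0, w1, w2]}    # w0..w2 are parameter tensors
    # so the id_map built below becomes {0: w0, 1: w1, 2: w2}, which
    # load_states_into_optimizer later uses to re-key each loaded state.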
+ saved_groups = torch.load(param_group_path) + if not isinstance(saved_groups, List): + raise ValueError(f'The param_groups saved at {param_group_path} is not of List type') + + # The params in param_groups are in the form of pytorch tensors. + # For more details, please view source code of Optimizer class in pytorch. + param_groups = optimizer.param_groups + + # Check the compatibility of saved_groups and param_groups. + if len(param_groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of original parameter groups") + param_lens = (len(g['params']) for g in param_groups) + saved_lens = (len(g['params']) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError("loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group") + + # Creating mapping from id to parameters. + id_map = { + old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups + )), chain.from_iterable((g['params'] for g in param_groups))) + } + + # Update parameter groups, setting their 'params' value. + def update_group(group, new_group): + new_group['params'] = group['params'] + return new_group + + updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)] + + optimizer.__dict__.update({'param_groups': updated_groups}) + return id_map + + +def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict): + r"""Copies states from `state_dict` into an Optimizer object. + + Args: + optimizer(Optimizer): An initialized Optimizer object to be loaded + state_dict(dict): a mapping from tensor index (an integer) + to its states to be loaded (a mapping from state name to a tensor). + id_map(dict): a mapping from tensor index (an integer) + to its corresponding parameter (a tensor) whose states will be updated. + """ + + def cast(param, value, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + if (key != "step"): + if param.is_floating_point(): + value = value.to(param.dtype) + value = value.to(param.device) + return value + elif isinstance(value, dict): + return {k: cast(param, v, key=k) for k, v in value.items()} + elif isinstance(value, container_abcs.Iterable): + return type(value)(cast(param, v) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + new_states = defaultdict(dict) + for k, v in state_dict.items(): + if k in id_map: + param = id_map[k] + new_states[param] = cast(param, v) + else: + new_states[k] = v + + optimzier.state.update(new_states) + + +def sharded_optimizer_loading_epilogue(optimizer: Optimizer): + # Do the cleaning up as in src code of Pytorch. + optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. 
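    # Illustrative end state (Adam states assumed): once the helpers above and this
    # epilogue have run, optimizer states are keyed by parameter objects again, e.g.
    #     optimizer.state[param]['exp_avg']    # moved/cast to param.device / param.dtype
    #     optimizer.state[param]['step']       # special-cased in cast(), not dtype-cast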
+ optimizer.defaults.setdefault('differentiable', False) + + # ====================================== # Helper functions for saving state dict # ====================================== @@ -203,6 +344,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors torch.save(state_dict, checkpoint_file_path) +def save_param_groups(state_dict: dict, group_file_path: str) -> None: + """ + Save information of param_groups to given file path. + + Args: + state_dict (dict): state dict. + group_file_path (str): path to the group file. + """ + param_groups = state_dict["param_groups"] + torch.save(param_groups, group_file_path) + + def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: """ Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains @@ -392,28 +545,44 @@ def load_state_dict(checkpoint_file_path: Path): return torch.load(checkpoint_file_path) -def add_variant(weights_name: str, variant: Optional[str] = None) -> str: - if variant is not None and len(variant) > 0: +def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str: + if prefix is not None and len(prefix) > 0: splits = weights_name.split(".") - splits = splits[:-1] + [variant] + splits[-1:] + splits = splits[:-1] + [prefix] + splits[-1:] weights_name = ".".join(splits) return weights_name -def get_base_filenames(variant: str = None, use_safetensors: bool = False): +def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False): """ - generate base weight filenames + generate base model weight filenames """ weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME - weights_name = add_variant(weights_name, variant) + weights_name = add_prefix(weights_name, prefix) save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - save_index_file = add_variant(save_index_file, variant) + save_index_file = add_prefix(save_index_file, prefix) return weights_name, save_index_file +def get_optimizer_base_filenames(prefix: str = None): + """ + generate base optimizer state filenames + """ + states_name = STATES_NAME + states_name = add_prefix(states_name, prefix) + + save_index_file = STATES_INDEX_NAME + save_index_file = add_prefix(save_index_file, prefix) + + param_group_file = GROUP_FILE_NAME + param_group_file = add_prefix(param_group_file, prefix) + + return states_name, save_index_file, param_group_file + + def get_shard_filename(weights_name: str, idx: int): """ get shard file name diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 9e973bb23..88e3673c1 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -60,7 +60,7 @@ def test_unsharded_checkpoint(use_safetensors: bool): @pytest.mark.parametrize('use_safetensors', [True, False]) -def test_sharded_checkpoint(use_safetensors: bool): +def test_sharded_model_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) @@ -100,3 +100,101 @@ def test_sharded_checkpoint(use_safetensors: bool): # check for model and optimizer state dict recursively check_state_dict_equal(model.state_dict(), new_model.state_dict()) check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) + + +def test_sharded_optimizer_checkpoint(): + + # create a model and optimizer + model = 
resnet18() + optimizer = Adam(model.parameters(), lr=0.001) + + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create temp directories for checkpoint + model_ckpt_dir = tempfile.TemporaryDirectory() + optimizer_ckpt_dir = tempfile.TemporaryDirectory() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + + # create new model + new_model = resnet18() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name)) + + # check for model and optimizer state dict recursively + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) + + # continue running fwd and bwd + for _ in range(5): + y = new_model(x) + loss = y.sum() + loss.backward() + new_optimizer.step() + + # save the newly got optimizer + ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + + # create another new model + new_new_model = resnet18() + new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001) + + ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name)) + + # check for model and optimizer state dict recursively + check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict()) + check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict()) + + +def test_sharded_optimizer_multiple_param_groups(): + + # create a model and optimizer + model = resnet18() + optimizer = Adam([{'params': model.layer1.parameters()}, \ + {'params': model.layer2.parameters(), 'lr': 0.002}], lr=0.001) + + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create temp directories for checkpoint + model_ckpt_dir = tempfile.TemporaryDirectory() + optimizer_ckpt_dir = tempfile.TemporaryDirectory() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + + # create new model + new_model = resnet18() + new_optimizer = Adam([{'params': new_model.layer1.parameters()}, \ + {'params': new_model.layer2.parameters(), 'lr': 0.002}], lr=0.001) + + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name)) + + # check for model and optimizer state dict recursively + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) From 725af3eeeb16f7a348578e19105bed4f4096e0ca Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 15 Jun 2023 17:38:42 +0800 Subject: [PATCH 7/9] [booster] make optimizer argument optional for boost (#3993) * feat: make optimizer optional in Booster.boost * test: skip unet test if diffusers 
version > 0.10.2 --- colossalai/booster/booster.py | 8 +++--- .../booster/mixed_precision/fp16_torch.py | 8 +++--- .../mixed_precision/mixed_precision_base.py | 7 +++--- colossalai/booster/plugin/gemini_plugin.py | 18 +++++++------ .../booster/plugin/low_level_zero_plugin.py | 18 +++++++------ colossalai/booster/plugin/plugin_base.py | 12 ++++----- colossalai/booster/plugin/torch_ddp_plugin.py | 13 +++++----- .../booster/plugin/torch_fsdp_plugin.py | 25 ++++++++++--------- .../test_autochunk_unet.py | 11 ++++++-- 9 files changed, 70 insertions(+), 50 deletions(-) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 4a42e2049..6e480d0db 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -97,10 +97,10 @@ class Booster: def boost( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: """ Boost the model, optimizer, criterion, lr_scheduler, and dataloader. diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index 9999aa5e0..26fd92bd5 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -115,10 +115,12 @@ class FP16TorchMixedPrecision(MixedPrecision): def configure(self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable]: model = TorchAMPModule(model) - optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) + if optimizer is not None: + optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) if criterion is not None: criterion = TorchAMPModule(criterion) return model, optimizer, criterion diff --git a/colossalai/booster/mixed_precision/mixed_precision_base.py b/colossalai/booster/mixed_precision/mixed_precision_base.py index 2490e9811..8caa34e50 100644 --- a/colossalai/booster/mixed_precision/mixed_precision_base.py +++ b/colossalai/booster/mixed_precision/mixed_precision_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple import torch.nn as nn from torch.optim import Optimizer @@ -15,7 +15,8 @@ class MixedPrecision(ABC): @abstractmethod def configure(self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable]: # TODO: implement this method pass diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ce01ad111..60b25b2c4 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -274,11 +274,11 @@ class GeminiPlugin(DPPluginBase): def configure( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, - ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + optimizer: 
Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: if not isinstance(model, ModelWrapper): # convert model to sync bn @@ -293,8 +293,12 @@ class GeminiPlugin(DPPluginBase): # wrap the model with Gemini model = GeminiModel(model, self.gemini_config, self.verbose) - if not isinstance(optimizer, OptimizerWrapper): - optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, + if optimizer is not None and \ + not isinstance(optimizer, OptimizerWrapper): + optimizer = GeminiOptimizer(model.unwrap(), + optimizer, + self.zero_optim_config, + self.optim_kwargs, self.verbose) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 2b312d0f9..94d722080 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase): def configure( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, - ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: if not isinstance(model, ModelWrapper): model = LowLevelZeroModel(model, self.stage, self.precision) - if not isinstance(optimizer, OptimizerWrapper): - optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, + if optimizer is not None and \ + not isinstance(optimizer, OptimizerWrapper): + optimizer = LowLevelZeroOptimizer(model.unwrap(), + optimizer, + self.zero_optim_config, + self.optim_kwargs, self.verbose) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py index 561f58bc5..aa78f6827 100644 --- a/colossalai/booster/plugin/plugin_base.py +++ b/colossalai/booster/plugin/plugin_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Iterator, List, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple, Union import torch.nn as nn from torch.optim import Optimizer @@ -38,11 +38,11 @@ class Plugin(ABC): def configure( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, - ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: # implement this method pass diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index a18073db6..4bfd61af3 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -138,11 +138,11 @@ class TorchDDPPlugin(DPPluginBase): def configure( self, model: nn.Module, - 
optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, - ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: # cast model to cuda model = model.cuda() @@ -152,7 +152,8 @@ class TorchDDPPlugin(DPPluginBase): # wrap the model with PyTorch DDP model = TorchDDPModel(model, **self.ddp_kwargs) - if not isinstance(optimizer, OptimizerWrapper): + if optimizer is not None and \ + not isinstance(optimizer, OptimizerWrapper): optimizer = OptimizerWrapper(optimizer) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index ebd03b6ea..abfffa9b0 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -195,23 +195,24 @@ class TorchFSDPPlugin(DPPluginBase): def configure( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, - ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: # wrap the model with PyTorch FSDP fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs) - if len(optimizer.param_groups) > 1: - warnings.warn( - 'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.' - ) - optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) + if optimizer is not None: + if len(optimizer.param_groups) > 1: + warnings.warn( + 'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.' 
+ ) + optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) - if not isinstance(optimizer, FSDPOptimizerWrapper): - optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model) + if not isinstance(optimizer, FSDPOptimizerWrapper): + optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model) return fsdp_model, optimizer, criterion, dataloader, lr_scheduler diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index ff0d4a1b5..fc9d8455e 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -4,12 +4,15 @@ import pytest import torch try: - from diffusers import UNet2DModel - MODELS = [UNet2DModel] + import diffusers + MODELS = [diffusers.UNet2DModel] HAS_REPO = True + from packaging import version + SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2") except: MODELS = [] HAS_REPO = False + SKIP_UNET_TEST = False from test_autochunk_diffuser_utils import run_test @@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]: return meta_args, concrete_args +@pytest.mark.skipif( + SKIP_UNET_TEST, + reason="diffusers version > 0.10.2", +) @pytest.mark.skipif( not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", From 822c3d4d66d2d74cb7c7080abed6a207602dddfd Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 16 Jun 2023 14:14:05 +0800 Subject: [PATCH 8/9] [checkpointio] sharded optimizer checkpoint for DDP plugin (#4002) --- colossalai/booster/booster.py | 49 ++++++++++++++----- colossalai/booster/plugin/torch_ddp_plugin.py | 10 ++-- .../checkpoint_io/checkpoint_io_base.py | 8 +-- .../checkpoint_io/general_checkpoint_io.py | 10 +++- colossalai/checkpoint_io/utils.py | 16 +++++- .../test_torch_ddp_checkpoint_io.py | 20 ++++---- 6 files changed, 79 insertions(+), 34 deletions(-) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 6e480d0db..cee547b33 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.interface import ModelWrapper from .accelerator import Accelerator from .mixed_precision import MixedPrecision, mixed_precision_factory @@ -165,11 +166,11 @@ class Booster: assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' return self.plugin.no_sync(model) - def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True): """Load model from checkpoint. Args: - model (nn.Module): A model boosted by Booster. + model (nn.Module or ModelWrapper): A model boosted by Booster. checkpoint (str): Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path. 
strict (bool, optional): whether to strictly enforce that the keys @@ -179,24 +180,34 @@ class Booster: self.checkpoint_io.load_model(model, checkpoint, strict) def save_model(self, - model: nn.Module, + model: Union[nn.Module, ModelWrapper], checkpoint: str, - prefix: str = None, shard: bool = False, - size_per_shard: int = 1024): + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False): """Save model to checkpoint. Args: - model (nn.Module): A model boosted by Booster. + model (nn.Module or ModelWrapper): A model boosted by Booster. checkpoint (str): Path to the checkpoint. It must be a local path. It is a file path if ``shard=False``. Otherwise, it is a directory path. - prefix (str, optional): A prefix added to parameter and buffer - names to compose the keys in state_dict. Defaults to None. shard (bool, optional): Whether to save checkpoint a sharded way. If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False. + gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True. + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. + use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved. """ - self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard) + self.checkpoint_io.save_model(model, + checkpoint=checkpoint, + shard=shard, + gather_dtensor=gather_dtensor, + prefix=prefix, + size_per_shard=size_per_shard, + use_safetensors=use_safetensors) def load_optimizer(self, optimizer: Optimizer, checkpoint: str): """Load optimizer from checkpoint. @@ -205,12 +216,21 @@ class Booster: optimizer (Optimizer): An optimizer boosted by Booster. checkpoint (str): Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path. + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ self.checkpoint_io.load_optimizer(optimizer, checkpoint) - def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024): - """Save optimizer to checkpoint. - Warning: Saving sharded optimizer checkpoint is not supported yet. + def save_optimizer(self, + optimizer: Optimizer, + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024): + """ + Save optimizer to checkpoint. Args: optimizer (Optimizer): An optimizer boosted by Booster. @@ -218,9 +238,12 @@ class Booster: It is a file path if ``shard=False``. Otherwise, it is a directory path. shard (bool, optional): Whether to save checkpoint a sharded way. If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. 
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ - self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard) + self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """Save lr scheduler to checkpoint. diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 4bfd61af3..71b435155 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -52,7 +52,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): def save_sharded_model(self, model: nn.Module, checkpoint_path: str, - gather_dtensor: bool = False, + gather_dtensor: bool = True, prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): @@ -62,8 +62,12 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors) - def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer(self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024): """ Save optimizer to checkpoint but only on master process. """ diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 9d513043f..8ff9d87c2 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -148,6 +148,9 @@ class CheckpointIO(ABC): Args: optimizer (Optimizer): optimizer to be loaded. checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ index_file_exists, index_file_path = has_index_file(checkpoint) @@ -157,7 +160,7 @@ class CheckpointIO(ABC): if index_file_exists: # the existence of index file means it is a sharded checkpoint - self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard) + self.load_sharded_optimizer(optimizer, index_file_path, prefix) else: self.load_unsharded_optimizer(optimizer, checkpoint) @@ -251,7 +254,7 @@ class CheckpointIO(ABC): # ======================================================== @abstractmethod - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): """ Load optimizer from sharded checkpoint. @@ -259,7 +262,6 @@ class CheckpointIO(ABC): optimizer (Optimizer): optimizer to be loaded. index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. prefix (str): prefix for the optimizer checkpoint. - size_per_shard (int): size per shard in MB. 
""" pass diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index d8e133313..26cafcada 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,6 +8,8 @@ from typing import Iterator, Optional, OrderedDict, Tuple import torch.nn as nn from torch.optim import Optimizer +from colossalai.interface import OptimizerWrapper + from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( @@ -50,11 +52,15 @@ class GeneralCheckpointIO(CheckpointIO): # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): """ Load sharded optimizer with the given path to index file. """ - optimizer.load_state_dict + + # If optimizer is wrapped, unwrap it. + if isinstance(optimizer, OptimizerWrapper): + optimizer = optimizer.optim + # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 21b70343b..3dada00cd 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -139,6 +139,12 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> state_size = 0 isDTensor = False for state_tensor in state.values(): + + # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0), + # The calculation of tensor size should be skipped to avoid error. + if state_tensor is None: + continue + # If the states are stored as DTensors, mark isDTensor as true. if type(state_tensor) == DTensor: isDTensor = True @@ -271,7 +277,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str return id_map -def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: dict): +def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict): r"""Copies states from `state_dict` into an Optimizer object. Args: @@ -311,10 +317,16 @@ def load_states_into_optimizer(optimzier: Optimizer, state_dict: dict, id_map: d else: new_states[k] = v - optimzier.state.update(new_states) + optimizer.state.update(new_states) def sharded_optimizer_loading_epilogue(optimizer: Optimizer): + r"""Do the cleaning up work after state_dict has been loaded into optimizer + + Args: + optimizer(Optimizer): An optimizer object whose state has just been loaded. + """ + # Do the cleaning up as in src code of Pytorch. optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. 
optimizer.defaults.setdefault('differentiable', False) diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index 5501ee4e3..14332b5b3 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad @parameterize('shard', [True, False]) -def check_torch_ddp_checkpointIO(shard: bool): +@parameterize('size_per_shard', [16, 128]) +def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) model = resnet18() @@ -38,11 +39,9 @@ def check_torch_ddp_checkpointIO(shard: bool): model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler" - booster.save_model(model, model_ckpt_path, shard=shard) - if not shard: - # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint - booster.save_optimizer(optimizer, optimizer_ckpt_path) - booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path) + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path) dist.barrier() new_model = resnet18() @@ -55,11 +54,10 @@ def check_torch_ddp_checkpointIO(shard: bool): booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) - if not shard: - booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) - booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path) - check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path) + check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) def run_dist(rank, world_size, port): From a5883aa7909070480d218b62ff8f3e987e7eebd8 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 16 Jun 2023 18:23:02 +0800 Subject: [PATCH 9/9] [test] fixed codefactor format report (#4026) --- .../test_general_checkpoint_io.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 88e3673c1..0976d4503 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -165,8 +165,13 @@ def test_sharded_optimizer_multiple_param_groups(): # create a model and optimizer model = resnet18() - optimizer = Adam([{'params': model.layer1.parameters()}, \ - {'params': model.layer2.parameters(), 'lr': 0.002}], lr=0.001) + optimizer = Adam([{ + 'params': model.layer1.parameters() + }, { + 'params': model.layer2.parameters(), + 'lr': 0.002 + }], + lr=0.001) # create test data sample x = torch.randn(1, 3, 224, 224) @@ -189,8 +194,13 @@ def test_sharded_optimizer_multiple_param_groups(): # create new model new_model = resnet18() - new_optimizer = Adam([{'params': new_model.layer1.parameters()}, \ - {'params': 
new_model.layer2.parameters(), 'lr': 0.002}], lr=0.001) + new_optimizer = Adam([{ + 'params': new_model.layer1.parameters() + }, { + 'params': new_model.layer2.parameters(), + 'lr': 0.002 + }], + lr=0.001) ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
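
With the sharded optimizer checkpointing from #4002 in place, the TorchDDP plugin can now round-trip both model and optimizer shards through the Booster front end, which is what the reworked check_torch_ddp_checkpointIO test above exercises. The following is a minimal sketch of that flow, not a definitive recipe: the torchrun-style launch, the ResNet-18/SGD pairing, the single dummy training step, and the ./ckpt/* paths are all stand-ins introduced for illustration.

import torch
import torch.distributed as dist
from torch.optim import SGD
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin


def main():
    # Assumes the script is launched with torchrun so the process group can be
    # initialized from environment variables.
    colossalai.launch_from_torch(config={})

    booster = Booster(plugin=TorchDDPPlugin())
    model = resnet18()
    optimizer = SGD(model.parameters(), lr=1e-3, momentum=0.9)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # One dummy step so the optimizer holds some state worth checkpointing.
    loss = model(torch.randn(4, 3, 224, 224, device='cuda')).sum()
    booster.backward(loss, optimizer)
    optimizer.step()

    # shard=True writes a directory of shards plus an index file; size_per_shard
    # caps each shard in MB. The ./ckpt/* paths are hypothetical.
    booster.save_model(model, './ckpt/model', shard=True, size_per_shard=128)
    booster.save_optimizer(optimizer, './ckpt/optimizer', shard=True, size_per_shard=128)
    dist.barrier()

    # Reload into fresh objects. load_optimizer detects the index file and takes
    # the sharded loading path shown in checkpoint_io_base.py above.
    new_model = resnet18()
    new_optimizer = SGD(new_model.parameters(), lr=1e-3, momentum=0.9)
    new_model, new_optimizer, *_ = booster.boost(new_model, new_optimizer)
    booster.load_model(new_model, './ckpt/model')
    booster.load_optimizer(new_optimizer, './ckpt/optimizer')


if __name__ == '__main__':
    main()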
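
The Optional[...] signatures adopted across the plugin configure() methods above also suggest an inference-only path where no optimizer is handed to the booster at all. The sketch below is more speculative than the previous one: it assumes Booster.boost forwards a missing optimizer as None down to plugin.configure (the intermediate patches of this series are not reproduced here), so treat the call shape as an assumption rather than documented behaviour.

import torch
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})
booster = Booster(plugin=TorchDDPPlugin())

# No optimizer or criterion: with the `optimizer is not None` guards added above,
# the plugin only wraps the model and passes the remaining entries through unchanged.
model, *_ = booster.boost(resnet18())

model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224, device='cuda'))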