Mirror of https://github.com/hpcaitech/ColossalAI, commit ca768eb62d
@@ -40,8 +40,8 @@ jobs:
       - name: Copy testmon cache
         run: | # branch name may contain slash, we need to replace it with space
           export REF_BRANCH=$(echo ${{ github.event.ref }} | sed "s/\// /")
-          if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ]; then
-            [ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ] && cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}"
+          if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ] && [ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ]; then
+            cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}"
           fi
         env:
           MAIN_BRANCH: ${{ github.event.master_branch }}
@@ -60,12 +60,15 @@ jobs:
     defaults:
       run:
         shell: bash
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - name: Copy testmon cache
         run: | # branch name may contain slash, we need to replace it with space
           export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
-          if [ -d "/github/home/testmon_cache/${BASE}" ]; then
-            [ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ] && mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER}
+          if [ -d "/github/home/testmon_cache/${BASE}" ] && [ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ]; then
+            mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER}
           fi
         env:
           PR_NUMBER: ${{ github.event.number }}
@@ -83,6 +86,9 @@ jobs:
       changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }}
       anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }}
     runs-on: ubuntu-latest
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - uses: actions/checkout@v2
         with:
@@ -140,6 +146,9 @@ jobs:
     defaults:
       run:
         shell: bash
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - name: Checkout TensorNVMe
         uses: actions/checkout@v2
@@ -150,7 +159,9 @@ jobs:
 
       - name: Restore TensorNVMe Cache
         run: |
-          [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ] && cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
+          if [ -d /github/home/tensornvme_cache ] && [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ]; then
+            cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
+          fi
 
       - name: Install TensorNVMe
         run: |
@@ -173,7 +184,9 @@ jobs:
         if: needs.detect.outputs.anyExtensionFileChanged != 'true'
         run: |
           # -p flag is required to preserve the file timestamp to avoid ninja rebuild
-          [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
+          if [ -d /github/home/cuda_ext_cache ] && [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ]; then
+            cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
+          fi
 
       - name: Install Colossal-AI
         run: |
@@ -264,8 +277,8 @@ jobs:
         if: github.event.pull_request.merged == true
         run: | # branch name may contain slash, we need to replace it with space
           export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
-          if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ]; then
-            [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ] && cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/"
+          if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ] && [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ]; then
+            cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/"
           fi
         env:
           PR_NUMBER: ${{ github.event.pull_request.number }}

@@ -12,6 +12,9 @@ jobs:
     runs-on: ubuntu-latest
     outputs:
       matrix: ${{ steps.set-matrix.outputs.matrix }}
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - uses: actions/checkout@v3
       - id: set-matrix
@@ -40,6 +43,9 @@ jobs:
       image: ${{ matrix.container }}
       options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
     timeout-minutes: 120
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - name: Install dependencies
         run: |

@@ -16,6 +16,9 @@ jobs:
       github.event.pull_request.draft == false &&
       github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
     runs-on: ubuntu-latest
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - uses: actions/checkout@v2
 
@@ -31,6 +34,9 @@ jobs:
       github.event.pull_request.draft == false &&
       github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
     runs-on: ubuntu-latest
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - uses: actions/checkout@v2
         with:

@@ -19,6 +19,9 @@ jobs:
     outputs:
       any_changed: ${{ steps.changed-files.outputs.any_changed }}
       changed_files: ${{ steps.changed-files.outputs.all_changed_files }}
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     name: Detect changed example files
     steps:
       - uses: actions/checkout@v3
@@ -59,6 +62,9 @@ jobs:
     defaults:
       run:
         shell: bash
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - name: Checkout ColossalAI-Documentation
         uses: actions/checkout@v2

@@ -20,6 +20,9 @@ jobs:
       matrix: ${{ steps.setup-matrix.outputs.matrix }}
       anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
     name: Detect changed example files
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - uses: actions/checkout@v3
         with:
@@ -77,6 +80,9 @@ jobs:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/
     timeout-minutes: 10
+    concurrency:
+      group: ${{ github.head_ref }}
+      cancel-in-progress: false
     steps:
       - uses: actions/checkout@v3
 

@@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.interface import ModelWrapper
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -97,10 +98,10 @@ class Booster:
     def boost(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
     ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
         """
         Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
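With every argument after `model` now Optional, `boost` can be driven without an optimizer, for example in an inference-only run. A minimal sketch of such a call, assuming the distributed state has already been set up (e.g. via colossalai.launch) and using a toy torch.nn model; the model, shapes and default Booster() construction here are illustrative, not part of the diff:

    import torch.nn as nn
    from colossalai.booster import Booster

    # Hypothetical inference-only setup: no optimizer, criterion, dataloader
    # or lr_scheduler is passed, which the new Optional[...] defaults allow.
    booster = Booster()
    model = nn.Linear(16, 4)

    # boost() still returns all five slots; the ones not supplied come back as None.
    model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model=model)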
@@ -165,11 +166,11 @@ class Booster:
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
 
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
             strict (bool, optional): whether to strictly enforce that the keys
@@ -179,24 +180,34 @@ class Booster:
         self.checkpoint_io.load_model(model, checkpoint, strict)
 
     def save_model(self,
-                   model: nn.Module,
+                   model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
-                   prefix: str = None,
                    shard: bool = False,
-                   size_per_shard: int = 1024):
+                   gather_dtensor: bool = True,
+                   prefix: Optional[str] = None,
+                   size_per_shard: int = 1024,
+                   use_safetensors: bool = False):
         """Save model to checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
-            prefix (str, optional): A prefix added to parameter and buffer
-                names to compose the keys in state_dict. Defaults to None.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
         """
-        self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
+        self.checkpoint_io.save_model(model,
+                                      checkpoint=checkpoint,
+                                      shard=shard,
+                                      gather_dtensor=gather_dtensor,
+                                      prefix=prefix,
+                                      size_per_shard=size_per_shard,
+                                      use_safetensors=use_safetensors)
 
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """Load optimizer from checkpoint.
@@ -205,12 +216,21 @@ class Booster:
             optimizer (Optimizer): An optimizer boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
         self.checkpoint_io.load_optimizer(optimizer, checkpoint)
 
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
-        """Save optimizer to checkpoint.
-        Warning: Saving sharded optimizer checkpoint is not supported yet.
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       checkpoint: str,
+                       shard: bool = False,
+                       gather_dtensor: bool = True,
+                       prefix: Optional[str] = None,
+                       size_per_shard: int = 1024):
+        """
+        Save optimizer to checkpoint.
 
         Args:
             optimizer (Optimizer): An optimizer boosted by Booster.
@@ -218,9 +238,12 @@ class Booster:
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
+        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """Save lr scheduler to checkpoint.
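Taken together, the two methods above give the checkpointing entry points their full keyword surface. A hedged sketch of the call sites, reusing booster, model and optimizer from an earlier boost call; the paths are illustrative:

    # Sharded model checkpoint: with shard=True the checkpoint path is a directory.
    booster.save_model(model,
                       checkpoint="./ckpt/model",
                       shard=True,
                       gather_dtensor=True,
                       prefix=None,
                       size_per_shard=1024,
                       use_safetensors=False)

    # Sharded optimizer checkpoint with the matching keyword layout.
    booster.save_optimizer(optimizer,
                           checkpoint="./ckpt/optim",
                           shard=True,
                           gather_dtensor=True,
                           prefix=None,
                           size_per_shard=1024)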

@@ -115,10 +115,12 @@ class FP16TorchMixedPrecision(MixedPrecision):
 
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
-        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+        if optimizer is not None:
+            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
         if criterion is not None:
             criterion = TorchAMPModule(criterion)
         return model, optimizer, criterion

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
     @abstractmethod
     def configure(self,
                   model: nn.Module,
-                  optimizer: Optimizer,
-                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+                  optimizer: Optional[Optimizer] = None,
+                  criterion: Optional[Callable] = None,
+                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         # TODO: implement this method
         pass

@@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
-from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
+from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                            model: GeminiDDP,
                            checkpoint_path: str,
                            gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
         Save sharded model
         """
         state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
-        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
         for idx, shard_pair in enumerate(state_dict_shard):
@@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
 
         index_file.append_meta_data("total_size", total_size)
-        index_file.write_index_file(save_index_file)
-        logging.info(f"The model is going to be split to checkpoint shards. "
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.write_index_file(save_index_file)
+        logging.info(f"The model is split into checkpoint shards. "
                      f"You can find where each parameters has been saved in the "
                      f"index located at {save_index_file}.")
 
@@ -271,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
@@ -290,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
             # wrap the model with Gemini
             model = GeminiModel(model, self.gemini_config, self.verbose)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = GeminiOptimizer(model.unwrap(),
+                                        optimizer,
+                                        self.zero_optim_config,
+                                        self.optim_kwargs,
                                         self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler

@@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.stage, self.precision)
 
-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
+            optimizer = LowLevelZeroOptimizer(model.unwrap(),
+                                              optimizer,
+                                              self.zero_optim_config,
+                                              self.optim_kwargs,
                                               self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -38,11 +38,11 @@ class Plugin(ABC):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # implement this method
         pass
 

@@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         """
         Save model to checkpoint but only on master process.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         if self.coordinator.is_master():
             super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
@@ -53,12 +52,27 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def save_sharded_model(self,
                            model: nn.Module,
                            checkpoint_path: str,
-                           gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
+        """
+        Save model to checkpoint but only on master process.
+        """
         if self.coordinator.is_master():
-            super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
+            super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+
+    def save_sharded_optimizer(self,
+                               optimizer: Optimizer,
+                               checkpoint: str,
+                               gather_dtensor: bool = True,
+                               prefix: Optional[str] = None,
+                               size_per_shard: int = 1024):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
 
 
 class TorchDDPModel(ModelWrapper):
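Every write path in this class is gated on self.coordinator.is_master(), so only rank 0 touches the filesystem while the other ranks fall through. A minimal standalone sketch of that pattern, using DistCoordinator from colossalai.cluster (seen in the imports earlier in this diff); it assumes the process group is already initialized, and the payload and path are illustrative:

    import torch
    from colossalai.cluster import DistCoordinator

    def save_on_master(state: dict, path: str) -> None:
        # Mirror the is_master() guard used above: rank 0 writes, other ranks no-op.
        coordinator = DistCoordinator()
        if coordinator.is_master():
            torch.save(state, path)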
@@ -128,11 +142,11 @@ class TorchDDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         # cast model to cuda
         model = model.cuda()
 
@@ -142,7 +156,8 @@ class TorchDDPPlugin(DPPluginBase):
         # wrap the model with PyTorch DDP
         model = TorchDDPModel(model, **self.ddp_kwargs)
 
-        if not isinstance(optimizer, OptimizerWrapper):
+        if optimizer is not None and \
+                not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
 
         return model, optimizer, criterion, dataloader, lr_scheduler

@@ -1,9 +1,9 @@
+import warnings
 from pathlib import Path
 from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import warnings
 from packaging import version
 from torch.distributed import ProcessGroup
 
@@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
         utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
 
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
                            size_per_shard: int, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
@@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         raise NotImplementedError("Sharded model checkpoint is not supported yet.")
 
-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
         """
         Save optimizer to checkpoint but only on master process.
         """
         raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
         """
         Load optimizer to checkpoint but only on master process.
         """
@@ -194,23 +195,24 @@ class TorchFSDPPlugin(DPPluginBase):
     def configure(
         self,
         model: nn.Module,
-        optimizer: Optimizer,
-        criterion: Callable = None,
-        dataloader: DataLoader = None,
-        lr_scheduler: LRScheduler = None,
-    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+        dataloader: Optional[DataLoader] = None,
+        lr_scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
 
-        if len(optimizer.param_groups) > 1:
-            warnings.warn(
-                'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
-            )
-        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+        if optimizer is not None:
+            if len(optimizer.param_groups) > 1:
+                warnings.warn(
+                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                )
+            optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
 
         if not isinstance(optimizer, FSDPOptimizerWrapper):
             optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
 
         return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
 

@@ -103,7 +103,7 @@ class CheckpointIO(ABC):
                    checkpoint: str,
                    shard: bool = False,
                    gather_dtensor: bool = True,
-                   variant: str = None,
+                   prefix: str = None,
                    size_per_shard: int = 1024,
                    use_safetensors: bool = False):
         """
@@ -128,7 +128,7 @@ class CheckpointIO(ABC):
                 multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
             gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
-            variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
+            prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
@@ -137,17 +137,20 @@ class CheckpointIO(ABC):
             model = model.unwrap()
 
         if shard:
-            self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
+            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
         else:
             self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
         """
         Load optimizer from checkpoint.
 
         Args:
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
         index_file_exists, index_file_path = has_index_file(checkpoint)
 
@@ -157,7 +160,7 @@ class CheckpointIO(ABC):
 
         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path)
+            self.load_sharded_optimizer(optimizer, index_file_path, prefix)
         else:
             self.load_unsharded_optimizer(optimizer, checkpoint)
 
@@ -218,7 +221,7 @@ class CheckpointIO(ABC):
         pass
 
     @abstractmethod
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
                            size_per_shard: int, use_safetensors: bool):
         """
         Save model to sharded checkpoint.
@@ -251,7 +254,7 @@ class CheckpointIO(ABC):
     # ========================================================
 
     @abstractmethod
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
         Load optimizer from sharded checkpoint.
 
@@ -259,7 +262,6 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             prefix (str): prefix for the optimizer checkpoint.
-            size_per_shard (int): size per shard in MB.
         """
         pass
 

@@ -8,18 +8,26 @@ from typing import Iterator, Optional, OrderedDict, Tuple
 import torch.nn as nn
 from torch.optim import Optimizer
 
+from colossalai.interface import OptimizerWrapper
+
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
-    get_base_filenames,
+    get_model_base_filenames,
+    get_optimizer_base_filenames,
     get_shard_filename,
     has_index_file,
     is_safetensors_available,
+    load_param_groups_into_optimizer,
     load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
+    load_states_into_optimizer,
+    save_param_groups,
     save_state_dict,
-    shard_checkpoint,
+    shard_model_checkpoint,
+    shard_optimizer_checkpoint,
+    sharded_optimizer_loading_epilogue,
 )
 
 __all__ = ['GeneralCheckpointIO']
@@ -44,12 +52,34 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
-
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        checkpoint = load_state_dict(checkpoint)
-        optimizer.load_state_dict(checkpoint)
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
+        """
+        Load sharded optimizer with the given path to index file.
+        """
+
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.optim
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory.')
+        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            load_states_into_optimizer(optimizer, state_dict, id_map)
+            del state_dict
+            gc.collect()
+
+        sharded_optimizer_loading_epilogue(optimizer)
 
     def save_sharded_optimizer(
         self,
@@ -59,7 +89,54 @@ class GeneralCheckpointIO(CheckpointIO):
         prefix: str,
         size_per_shard: int,
     ):
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+        """
+        Save sharded optimizer checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+        - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
+        """
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # Offload optimizer states. States are broken into shards within max_shard_size.
+        state_dict = optimizer.state_dict()
+        sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
+
+        # Preparing file paths and index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+
+        # Store the information of param groups to param_group_file.
+        index_file.append_meta_data("param_groups", param_group_file)
+        group_file_path = os.path.join(checkpoint, param_group_file)
+        save_param_groups(state_dict, group_file_path)
+
+        # Save shards of optimizer states.
+        total_size = 0
+        for idx, shard_pair in enumerate(sharded_state):
+            shard, current_size = shard_pair
+            shard_file = get_shard_filename(states_name, idx)
+            total_size = total_size + current_size
+            for param_id in shard.keys():
+                index_file.append_weight_map(str(param_id), shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+        # Wrap up index file.
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
+        logging.info(f"The optimizer is going to be split to checkpoint shards. "
+                     f"You can find where each parameters has been saved in the "
+                     f"index located at {save_index_file}.")
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        checkpoint = load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)
 
     def save_unsharded_optimizer(
         self,
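With save_sharded_optimizer and load_sharded_optimizer implemented, a sharded optimizer checkpoint can be written and read back through the public CheckpointIO entry points. A hedged round-trip sketch on a single process; the directory, model and hyperparameters are illustrative, and the keyword names are assumed to mirror the Booster-level API shown earlier in this diff:

    import torch
    import torch.nn as nn
    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(2, 8)).sum().backward()
    optimizer.step()    # populate exp_avg / exp_avg_sq states so there is something to shard

    ckpt_io = GeneralCheckpointIO()
    # Expected to write pytorch_optim.bin.index.json, pytorch_optim_group.bin and
    # pytorch_optim-000XX.bin shards under ./optim_ckpt, per the docstring above.
    ckpt_io.save_optimizer(optimizer, "./optim_ckpt", shard=True, size_per_shard=1024)

    # A freshly constructed optimizer is then restored from the same directory.
    new_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    ckpt_io.load_optimizer(new_optimizer, "./optim_ckpt")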
@@ -74,7 +151,7 @@ class GeneralCheckpointIO(CheckpointIO):
                            model: nn.Module,
                            checkpoint_path: str,
                            gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
@@ -89,9 +166,9 @@ class GeneralCheckpointIO(CheckpointIO):
 
         # shard checkpoint
         state_dict = model.state_dict()
-        state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+        state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
 
-        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
         for idx, shard_pair in enumerate(state_dict_shard):
@@ -128,7 +205,7 @@ class GeneralCheckpointIO(CheckpointIO):
 
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
-        checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         missing_keys = []
 
         for shard_file in checkpoint_files:

@@ -1,8 +1,8 @@
 import json
-from pathlib import Path
-from typing import Any, List, Union
 import os
-import json
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Dict, List, Union
 
 from .utils import is_dtensor_checkpoint
 
@@ -22,8 +22,10 @@ class CheckpointIndexFile:
 
     def __init__(self, root_path=None) -> None:
         self.root_path = root_path
-        self.metadata: dict = dict()
-        self.weight_map: dict = dict()
+        # use ordered dict to preserve the tensor checkpoint order
+        self.metadata: Dict = OrderedDict()
+        self.weight_map: Dict = OrderedDict()
 
     @staticmethod
     def from_file(index_path: Union[str, Path]):
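Because metadata and weight_map are now OrderedDicts, index entries are written out in the order they were appended. A small hedged sketch of how the index object is typically filled during a sharded optimizer save; the directory, shard names and sizes are illustrative, while the method names come from this class:

    import os
    from colossalai.checkpoint_io import CheckpointIndexFile

    os.makedirs("./optim_ckpt", exist_ok=True)
    index_file = CheckpointIndexFile("./optim_ckpt")

    # Hypothetical bookkeeping for a two-shard optimizer checkpoint.
    index_file.append_meta_data("param_groups", "pytorch_optim_group.bin")
    index_file.append_weight_map("0", "pytorch_optim-00001.bin")
    index_file.append_weight_map("1", "pytorch_optim-00002.bin")
    index_file.append_meta_data("total_size", 4096)

    # Entries are serialized in insertion order now that OrderedDict is used.
    index_file.write_index_file("pytorch_optim.bin.index.json")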
@@ -109,7 +111,7 @@ class CheckpointIndexFile:
                 return True
         return False
 
-    def get_checkpoint_fileanames(self) -> List[str]:
+    def get_checkpoint_filenames(self) -> List[str]:
         """
         Get the set of checkpoint filenames in the weight map.
 
@@ -150,13 +152,25 @@ class CheckpointIndexFile:
         """
         ckpt_path = self.weight_map[param_name]
         return ckpt_path
 
     def get_all_param_names(self):
         """
         Get all the weight keys.
         """
         return list(self.weight_map.keys())
 
+    def get_param_group_filename(self) -> Union[str, None]:
+        """
+        Get the file name of param_group file if this is a checkpoint for optimizer.
+        Returns:
+            str: param_group file name
+        """
+        filename = self.metadata.get("param_groups", None)
+        if filename:
+            return str(self.root_path.joinpath(filename))
+        else:
+            return None
+
     def write_index_file(self, save_index_file):
         """
         Write index file.
@@ -164,5 +178,5 @@ class CheckpointIndexFile:
         save_index_file = os.path.join(self.root_path, save_index_file)
         index = {"metadata": self.metadata, "weight_map": self.weight_map}
         with open(save_index_file, "w", encoding="utf-8") as f:
-            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            content = json.dumps(index, indent=2) + "\n"
             f.write(content)

@@ -1,17 +1,24 @@
 # coding=utf-8
 import re
+from collections import abc as container_abcs
+from collections import defaultdict
+from itertools import chain
 from pathlib import Path
 from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
 
 import torch
 import torch.nn as nn
+from torch.optim import Optimizer
 
 from colossalai.tensor.d_tensor.d_tensor import DTensor
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
+STATES_NAME = "pytorch_optim.bin"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
+STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+GROUP_FILE_NAME = "pytorch_optim_group.bin"
 
 # ======================================
 # General helper functions
@@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # ======================================
 # Helper functions for saving shard file
 # ======================================
-def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@ -110,6 +117,56 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
|
||||||
yield current_block, current_block_size
|
yield current_block, current_block_size
|
||||||
|
|
||||||
|
|
||||||
|
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||||
|
"""
|
||||||
|
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||||
|
given size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
|
||||||
|
states = state_dict['state']
|
||||||
|
|
||||||
|
current_block = {}
|
||||||
|
current_block_size = 0
|
||||||
|
|
||||||
|
for param_id, state in states.items():
|
||||||
|
|
||||||
|
ret_block = None
|
||||||
|
ret_block_size = 0
|
||||||
|
|
||||||
|
# A state might contain more than one tensors.
|
||||||
|
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
|
||||||
|
state_size = 0
|
||||||
|
isDTensor = False
|
||||||
|
for state_tensor in state.values():
|
||||||
|
|
||||||
|
# When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
|
||||||
|
# The calculation of tensor size should be skipped to avoid error.
|
||||||
|
if state_tensor is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If the states are stored as DTensors, mark isDTensor as true.
|
||||||
|
if type(state_tensor) == DTensor:
|
||||||
|
isDTensor = True
|
||||||
|
state_size += calculate_tensor_size(state_tensor)
|
||||||
|
|
||||||
|
if not isDTensor:
|
||||||
|
|
||||||
|
if current_block_size + state_size > max_shard_size:
|
||||||
|
ret_block = current_block
|
||||||
|
ret_block_size = current_block_size
|
||||||
|
current_block = {}
|
||||||
|
current_block_size = 0
|
||||||
|
|
||||||
|
current_block[param_id] = state
|
||||||
|
current_block_size += state_size
|
||||||
|
|
||||||
|
if ret_block != None:
|
||||||
|
yield ret_block, ret_block_size
|
||||||
|
|
||||||
|
yield current_block, current_block_size
|
||||||
|
|
||||||
|
|
||||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||||
"""
|
"""
|
||||||
load shard state dict into model
|
load shard state dict into model
|
||||||
|
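To make the behaviour of the new `shard_optimizer_checkpoint` generator concrete, here is a minimal sketch that feeds it a hand-built optimizer state dict. It is not part of the diff; the module path `colossalai.checkpoint_io.utils` and the shard size are assumptions, and the state dict is a hypothetical stand-in for what `Optimizer.state_dict()` returns.

# Sketch only (not part of the diff). Assumes the hunk above lives in
# colossalai.checkpoint_io.utils; the state dict below is a hand-built stand-in.
import torch

from colossalai.checkpoint_io.utils import shard_optimizer_checkpoint

fake_optim_state_dict = {
    'state': {
        0: {'exp_avg': torch.zeros(1024), 'exp_avg_sq': torch.zeros(1024)},
        1: {'exp_avg': torch.zeros(2048), 'exp_avg_sq': torch.zeros(2048)},
    },
    'param_groups': [{'params': [0, 1], 'lr': 1e-3}],    # ignored: only 'state' is split
}

# Each yielded block maps param ids to their state tensors; DTensor-backed states
# are deliberately left out of the blocks, as in the hunk above.
for block, block_size in shard_optimizer_checkpoint(fake_optim_state_dict, max_shard_size=1024):
    print(sorted(block.keys()), block_size)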
@@ -179,6 +236,102 @@ def load_state_dict_into_model(model: nn.Module,
                                    model.__class__.__name__, "\n\t".join(error_msgs)))


+def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
+    """
+    Load information of param_groups into an initialized optimizer.
+    """
+
+    # Load list of param_groups from given file path.
+    # The params in saved_groups are in the form of integer indices.
+    saved_groups = torch.load(param_group_path)
+    if not isinstance(saved_groups, List):
+        raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
+
+    # The params in param_groups are in the form of pytorch tensors.
+    # For more details, please view source code of Optimizer class in pytorch.
+    param_groups = optimizer.param_groups
+
+    # Check the compatibility of saved_groups and param_groups.
+    if len(param_groups) != len(saved_groups):
+        raise ValueError("loaded state dict has a different number of original parameter groups")
+    param_lens = (len(g['params']) for g in param_groups)
+    saved_lens = (len(g['params']) for g in saved_groups)
+    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+        raise ValueError("loaded state dict contains a parameter group "
+                         "that doesn't match the size of optimizer's group")
+
+    # Creating mapping from id to parameters.
+    id_map = {
+        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
+                                                           )), chain.from_iterable((g['params'] for g in param_groups)))
+    }
+
+    # Update parameter groups, setting their 'params' value.
+    def update_group(group, new_group):
+        new_group['params'] = group['params']
+        return new_group
+
+    updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
+
+    optimizer.__dict__.update({'param_groups': updated_groups})
+    return id_map
+
+
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
+    r"""Copies states from `state_dict` into an Optimizer object.
+
+    Args:
+        optimizer(Optimizer): An initialized Optimizer object to be loaded
+        state_dict(dict): a mapping from tensor index (an integer)
+            to its states to be loaded (a mapping from state name to a tensor).
+        id_map(dict): a mapping from tensor index (an integer)
+            to its corresponding parameter (a tensor) whose states will be updated.
+    """
+
+    def cast(param, value, key=None):
+        r"""Make a deep copy of value, casting all tensors to device of param."""
+        if isinstance(value, torch.Tensor):
+            # Floating-point types are a bit special here. They are the only ones
+            # that are assumed to always match the type of params.
+            # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
+            if (key != "step"):
+                if param.is_floating_point():
+                    value = value.to(param.dtype)
+                value = value.to(param.device)
+            return value
+        elif isinstance(value, dict):
+            return {k: cast(param, v, key=k) for k, v in value.items()}
+        elif isinstance(value, container_abcs.Iterable):
+            return type(value)(cast(param, v) for v in value)
+        else:
+            return value
+
+    # Copy state assigned to params (and cast tensors to appropriate types).
+    # State that is not assigned to params is copied as is (needed for
+    # backward compatibility).
+    new_states = defaultdict(dict)
+    for k, v in state_dict.items():
+        if k in id_map:
+            param = id_map[k]
+            new_states[param] = cast(param, v)
+        else:
+            new_states[k] = v
+
+    optimizer.state.update(new_states)
+
+
+def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    r"""Do the cleaning up work after state_dict has been loaded into optimizer
+
+    Args:
+        optimizer(Optimizer): An optimizer object whose state has just been loaded.
+    """
+
+    # Do the cleaning up as in src code of Pytorch.
+    optimizer._hook_for_profile()    # To support multiprocessing pickle/unpickle.
+    optimizer.defaults.setdefault('differentiable', False)
+
+
 # ======================================
 # Helper functions for saving state dict
 # ======================================
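The three loaders above are meant to be chained when a sharded optimizer checkpoint is restored: first rebuild the param groups and the id mapping, then stream every state shard into the optimizer, and finally run the epilogue. The sketch below shows that order only; the module path, the file names and the use of `load_shard_state_dict` for optimizer shards are assumptions, not something this diff spells out.

# Sketch only (not part of the diff). File names are hypothetical placeholders and
# are assumed to have been produced by the saving-side helpers.
import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io.utils import (
    load_param_groups_into_optimizer,
    load_shard_state_dict,    # assumed to work for optimizer state shards as well
    load_states_into_optimizer,
    sharded_optimizer_loading_epilogue,
)

optimizer = Adam(nn.Linear(4, 4).parameters(), lr=1e-3)

# 1. Restore hyper-parameters and get the saved-id -> parameter mapping.
id_map = load_param_groups_into_optimizer(optimizer, "pytorch_optim_group.bin")

# 2. Feed every state shard (a mapping from param id to its states) into the optimizer.
for shard_file in ["pytorch_optim-00001.bin", "pytorch_optim-00002.bin"]:
    shard = load_shard_state_dict(shard_file, use_safetensors=False)
    load_states_into_optimizer(optimizer, shard, id_map)

# 3. Clean up, mirroring the tail of torch.optim.Optimizer.load_state_dict.
sharded_optimizer_loading_epilogue(optimizer)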
@@ -203,6 +356,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
     torch.save(state_dict, checkpoint_file_path)


+def save_param_groups(state_dict: dict, group_file_path: str) -> None:
+    """
+    Save information of param_groups to given file path.
+
+    Args:
+        state_dict (dict): state dict.
+        group_file_path (str): path to the group file.
+    """
+    param_groups = state_dict["param_groups"]
+    torch.save(param_groups, group_file_path)
+
+
 def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
     """
     Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
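Together with `shard_optimizer_checkpoint` and the existing `save_state_dict`, the new `save_param_groups` forms the saving-side counterpart of the loading flow sketched earlier: hyper-parameters go into one small group file, per-parameter states go into shards. A rough sketch under the same module-path assumption; the file names and shard size are placeholders.

# Sketch only (not part of the diff). Shard file names follow the constants above,
# but the exact naming scheme used by the checkpoint IO classes is not shown here.
import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io.utils import (
    save_param_groups,
    save_state_dict,
    shard_optimizer_checkpoint,
)

model = nn.Linear(4, 4)
optimizer = Adam(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
optimizer.step()    # populate the optimizer states

optim_state_dict = optimizer.state_dict()

# Param groups (lr, betas, param ids, ...) go into a single small file ...
save_param_groups(optim_state_dict, "pytorch_optim_group.bin")

# ... while the per-parameter states are written shard by shard.
for idx, (shard, _shard_size) in enumerate(shard_optimizer_checkpoint(optim_state_dict, max_shard_size=1024)):
    save_state_dict(shard, f"pytorch_optim-{idx + 1:05d}.bin", use_safetensors=False)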
@@ -392,28 +557,44 @@ def load_state_dict(checkpoint_file_path: Path):
     return torch.load(checkpoint_file_path)


-def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
-    if variant is not None and len(variant) > 0:
+def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
+    if prefix is not None and len(prefix) > 0:
         splits = weights_name.split(".")
-        splits = splits[:-1] + [variant] + splits[-1:]
+        splits = splits[:-1] + [prefix] + splits[-1:]
         weights_name = ".".join(splits)

     return weights_name


-def get_base_filenames(variant: str = None, use_safetensors: bool = False):
+def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
-    generate base weight filenames
+    generate base model weight filenames
     """
     weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
-    weights_name = add_variant(weights_name, variant)
+    weights_name = add_prefix(weights_name, prefix)

     save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-    save_index_file = add_variant(save_index_file, variant)
+    save_index_file = add_prefix(save_index_file, prefix)

     return weights_name, save_index_file


+def get_optimizer_base_filenames(prefix: str = None):
+    """
+    generate base optimizer state filenames
+    """
+    states_name = STATES_NAME
+    states_name = add_prefix(states_name, prefix)
+
+    save_index_file = STATES_INDEX_NAME
+    save_index_file = add_prefix(save_index_file, prefix)
+
+    param_group_file = GROUP_FILE_NAME
+    param_group_file = add_prefix(param_group_file, prefix)
+
+    return states_name, save_index_file, param_group_file
+
+
 def get_shard_filename(weights_name: str, idx: int):
     """
     get shard file name
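The renamed `add_prefix` inserts an optional prefix right before the final extension, and `get_optimizer_base_filenames` applies it to the three optimizer-related constants defined at the top of the file. A small sketch of the resulting names, assuming the same module path; the "epoch3" prefix is just an illustrative value.

# Sketch only (not part of the diff); the "epoch3" prefix is hypothetical.
from colossalai.checkpoint_io.utils import add_prefix, get_optimizer_base_filenames

assert add_prefix("pytorch_model.bin", "epoch3") == "pytorch_model.epoch3.bin"
assert add_prefix("pytorch_model.bin", None) == "pytorch_model.bin"    # no-op without a prefix

# For the optimizer: states file, index file and param-group file, all carrying the prefix.
# Note that the prefix always lands before the last dot-separated component, so for the
# index file it ends up between "index" and "json".
states_name, index_name, group_name = get_optimizer_base_filenames("epoch3")
assert states_name == "pytorch_optim.epoch3.bin"
assert index_name == "pytorch_optim.bin.index.epoch3.json"
assert group_name == "pytorch_optim_group.epoch3.bin"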
@@ -716,7 +716,10 @@ class _StateDictSharder:
         tensor_size = calculate_tensor_size(tensor)
         ret_block = None
         ret_block_size = 0
-        if self.current_block_size + tensor_size > self.max_shard_size:
+
+        # before we return the current block and create a new block,
+        # we need to ensure that the current block is not empty
+        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
             ret_block = self.current_block
             ret_block_size = self.current_block_size
             self.current_block = OrderedDict()
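The extra `self.current_block_size > 0` condition keeps `_StateDictSharder` from flushing an empty block when the very first tensor it sees is already larger than `max_shard_size`. The standalone sketch below reproduces just that guard with plain numbers; it is not the class itself, whose full definition lies outside this hunk.

# Sketch only (not part of the diff): a stripped-down version of the append logic.
# Sizes are abstract numbers here; the real class uses calculate_tensor_size.
from collections import OrderedDict

def append(current_block, current_block_size, name, tensor_size, max_shard_size):
    ret_block, ret_block_size = None, 0
    # Without the extra check, a first tensor larger than max_shard_size would
    # cause an empty OrderedDict to be flushed as a "shard".
    if current_block_size + tensor_size > max_shard_size and current_block_size > 0:
        ret_block, ret_block_size = current_block, current_block_size
        current_block, current_block_size = OrderedDict(), 0
    current_block[name] = tensor_size
    current_block_size += tensor_size
    return current_block, current_block_size, ret_block, ret_block_size

block, size = OrderedDict(), 0
block, size, ret, _ = append(block, size, "w1", 2048, max_shard_size=1024)
assert ret is None    # oversized first tensor stays in the current block
block, size, ret, _ = append(block, size, "w2", 16, max_shard_size=1024)
assert ret is not None and list(ret.keys()) == ["w1"]    # now the full block is flushed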
@@ -4,12 +4,15 @@ import pytest
 import torch

 try:
-    from diffusers import UNet2DModel
-    MODELS = [UNet2DModel]
+    import diffusers
+    MODELS = [diffusers.UNet2DModel]
     HAS_REPO = True
+    from packaging import version
+    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
 except:
     MODELS = []
     HAS_REPO = False
+    SKIP_UNET_TEST = False

 from test_autochunk_diffuser_utils import run_test

@@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args


+@pytest.mark.skipif(
+    SKIP_UNET_TEST,
+    reason="diffusers version > 0.10.2",
+)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
@@ -60,7 +60,7 @@ def test_unsharded_checkpoint(use_safetensors: bool):


 @pytest.mark.parametrize('use_safetensors', [True, False])
-def test_sharded_checkpoint(use_safetensors: bool):
+def test_sharded_model_checkpoint(use_safetensors: bool):
     # create a model and optimizer
     model = resnet18()
     optimizer = Adam(model.parameters(), lr=0.001)

@@ -100,3 +100,111 @@ def test_sharded_checkpoint(use_safetensors: bool):
     # check for model and optimizer state dict recursively
     check_state_dict_equal(model.state_dict(), new_model.state_dict())
     check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
+
+
+def test_sharded_optimizer_checkpoint():
+
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam(model.parameters(), lr=0.001)
+
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # create temp directories for checkpoint
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_dir = tempfile.TemporaryDirectory()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+
+    # create new model
+    new_model = resnet18()
+    new_optimizer = Adam(new_model.parameters(), lr=0.001)
+
+    ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
+
+    # check for model and optimizer state dict recursively
+    check_state_dict_equal(model.state_dict(), new_model.state_dict())
+    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
+
+    # continue running fwd and bwd
+    for _ in range(5):
+        y = new_model(x)
+        loss = y.sum()
+        loss.backward()
+        new_optimizer.step()
+
+    # save the newly got optimizer
+    ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
+    ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+
+    # create another new model
+    new_new_model = resnet18()
+    new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001)
+
+    ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name))
+
+    # check for model and optimizer state dict recursively
+    check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict())
+    check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())
+
+
+def test_sharded_optimizer_multiple_param_groups():
+
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam([{
+        'params': model.layer1.parameters()
+    }, {
+        'params': model.layer2.parameters(),
+        'lr': 0.002
+    }],
+                     lr=0.001)
+
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # create temp directories for checkpoint
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_dir = tempfile.TemporaryDirectory()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+
+    # create new model
+    new_model = resnet18()
+    new_optimizer = Adam([{
+        'params': new_model.layer1.parameters()
+    }, {
+        'params': new_model.layer2.parameters(),
+        'lr': 0.002
+    }],
+                         lr=0.001)
+
+    ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
+
+    # check for model and optimizer state dict recursively
+    check_state_dict_equal(model.state_dict(), new_model.state_dict())
+    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
@@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad


 @parameterize('shard', [True, False])
-def check_torch_ddp_checkpointIO(shard: bool):
+@parameterize('size_per_shard', [16, 128])
+def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()

@@ -38,11 +39,9 @@ def check_torch_ddp_checkpointIO(shard: bool):
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
-        booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
-            booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
         dist.barrier()

         new_model = resnet18()

@@ -55,11 +54,10 @@ def check_torch_ddp_checkpointIO(shard: bool):
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)

-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
-            booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
-            check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+        booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
+        check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)


 def run_dist(rank, world_size, port):