mirror of https://github.com/hpcaitech/ColossalAI
commit
ca768eb62d
|
@ -40,8 +40,8 @@ jobs:
|
|||
- name: Copy testmon cache
|
||||
run: | # branch name may contain slash, we need to replace it with space
|
||||
export REF_BRANCH=$(echo ${{ github.event.ref }} | sed "s/\// /")
|
||||
if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ]; then
|
||||
[ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ] && cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}"
|
||||
if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ] && [ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ]; then
|
||||
cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}"
|
||||
fi
|
||||
env:
|
||||
MAIN_BRANCH: ${{ github.event.master_branch }}
|
||||
|
@ -60,12 +60,15 @@ jobs:
|
|||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- name: Copy testmon cache
|
||||
run: | # branch name may contain slash, we need to replace it with space
|
||||
export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
|
||||
if [ -d "/github/home/testmon_cache/${BASE}" ]; then
|
||||
[ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ] && mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER}
|
||||
if [ -d "/github/home/testmon_cache/${BASE}" ] and [ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ]; then
|
||||
mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER}
|
||||
fi
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.number }}
|
||||
|
@ -83,6 +86,9 @@ jobs:
|
|||
changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }}
|
||||
anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }}
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
|
@ -140,6 +146,9 @@ jobs:
|
|||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- name: Checkout TensorNVMe
|
||||
uses: actions/checkout@v2
|
||||
|
@ -150,7 +159,9 @@ jobs:
|
|||
|
||||
- name: Restore TensorNVMe Cache
|
||||
run: |
|
||||
[ ! -z "$(ls -A /github/home/tensornvme_cache/)" ] && cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
|
||||
if [ -d /github/home/tensornvme_cache ] && [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ]; then
|
||||
cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
|
||||
fi
|
||||
|
||||
- name: Install TensorNVMe
|
||||
run: |
|
||||
|
@ -173,7 +184,9 @@ jobs:
|
|||
if: needs.detect.outputs.anyExtensionFileChanged != 'true'
|
||||
run: |
|
||||
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
|
||||
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
|
||||
if [ -d /github/home/cuda_ext_cache ] && [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ]; then
|
||||
cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
|
||||
fi
|
||||
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
|
@ -264,8 +277,8 @@ jobs:
|
|||
if: github.event.pull_request.merged == true
|
||||
run: | # branch name may contain slash, we need to replace it with space
|
||||
export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
|
||||
if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ]; then
|
||||
[ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ] && cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/"
|
||||
if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ] && [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ]; then
|
||||
cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/"
|
||||
fi
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
|
|
|
@ -12,6 +12,9 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- id: set-matrix
|
||||
|
@ -40,6 +43,9 @@ jobs:
|
|||
image: ${{ matrix.container }}
|
||||
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
|
||||
timeout-minutes: 120
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
|
|
|
@ -16,6 +16,9 @@ jobs:
|
|||
github.event.pull_request.draft == false &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
|
@ -31,6 +34,9 @@ jobs:
|
|||
github.event.pull_request.draft == false &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
|
|
|
@ -19,6 +19,9 @@ jobs:
|
|||
outputs:
|
||||
any_changed: ${{ steps.changed-files.outputs.any_changed }}
|
||||
changed_files: ${{ steps.changed-files.outputs.all_changed_files }}
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
name: Detect changed example files
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
@ -59,6 +62,9 @@ jobs:
|
|||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- name: Checkout ColossalAI-Documentation
|
||||
uses: actions/checkout@v2
|
||||
|
|
|
@ -20,6 +20,9 @@ jobs:
|
|||
matrix: ${{ steps.setup-matrix.outputs.matrix }}
|
||||
anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
|
||||
name: Detect changed example files
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
|
@ -77,6 +80,9 @@ jobs:
|
|||
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
|
||||
options: --gpus all --rm -v /data/scratch/examples-data:/data/
|
||||
timeout-minutes: 10
|
||||
concurrency:
|
||||
group: ${{ github.head_ref }}
|
||||
cancel-in-progress: false
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import GeneralCheckpointIO
|
||||
from colossalai.interface import ModelWrapper
|
||||
|
||||
from .accelerator import Accelerator
|
||||
from .mixed_precision import MixedPrecision, mixed_precision_factory
|
||||
|
@ -97,10 +98,10 @@ class Booster:
|
|||
def boost(
|
||||
self,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: Callable = None,
|
||||
dataloader: DataLoader = None,
|
||||
lr_scheduler: LRScheduler = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
criterion: Optional[Callable] = None,
|
||||
dataloader: Optional[DataLoader] = None,
|
||||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
|
||||
"""
|
||||
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
|
||||
|
@ -165,11 +166,11 @@ class Booster:
|
|||
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
|
||||
return self.plugin.no_sync(model)
|
||||
|
||||
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
|
||||
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
|
||||
"""Load model from checkpoint.
|
||||
|
||||
Args:
|
||||
model (nn.Module): A model boosted by Booster.
|
||||
model (nn.Module or ModelWrapper): A model boosted by Booster.
|
||||
checkpoint (str): Path to the checkpoint. It must be a local path.
|
||||
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
|
||||
strict (bool, optional): whether to strictly enforce that the keys
|
||||
|
@ -179,24 +180,34 @@ class Booster:
|
|||
self.checkpoint_io.load_model(model, checkpoint, strict)
|
||||
|
||||
def save_model(self,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, ModelWrapper],
|
||||
checkpoint: str,
|
||||
prefix: str = None,
|
||||
shard: bool = False,
|
||||
size_per_shard: int = 1024):
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
"""Save model to checkpoint.
|
||||
|
||||
Args:
|
||||
model (nn.Module): A model boosted by Booster.
|
||||
model (nn.Module or ModelWrapper): A model boosted by Booster.
|
||||
checkpoint (str): Path to the checkpoint. It must be a local path.
|
||||
It is a file path if ``shard=False``. Otherwise, it is a directory path.
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
shard (bool, optional): Whether to save checkpoint a sharded way.
|
||||
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
|
||||
gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
|
||||
"""
|
||||
self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
|
||||
self.checkpoint_io.save_model(model,
|
||||
checkpoint=checkpoint,
|
||||
shard=shard,
|
||||
gather_dtensor=gather_dtensor,
|
||||
prefix=prefix,
|
||||
size_per_shard=size_per_shard,
|
||||
use_safetensors=use_safetensors)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
||||
"""Load optimizer from checkpoint.
|
||||
|
@ -205,12 +216,21 @@ class Booster:
|
|||
optimizer (Optimizer): An optimizer boosted by Booster.
|
||||
checkpoint (str): Path to the checkpoint. It must be a local path.
|
||||
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||
"""
|
||||
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
|
||||
|
||||
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
|
||||
"""Save optimizer to checkpoint.
|
||||
Warning: Saving sharded optimizer checkpoint is not supported yet.
|
||||
def save_optimizer(self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: str,
|
||||
shard: bool = False,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024):
|
||||
"""
|
||||
Save optimizer to checkpoint.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): An optimizer boosted by Booster.
|
||||
|
@ -218,9 +238,12 @@ class Booster:
|
|||
It is a file path if ``shard=False``. Otherwise, it is a directory path.
|
||||
shard (bool, optional): Whether to save checkpoint a sharded way.
|
||||
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
|
||||
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||
"""
|
||||
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
|
||||
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
"""Save lr scheduler to checkpoint.
|
||||
|
|
|
@ -115,9 +115,11 @@ class FP16TorchMixedPrecision(MixedPrecision):
|
|||
|
||||
def configure(self,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
criterion: Optional[Callable] = None,
|
||||
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
|
||||
model = TorchAMPModule(model)
|
||||
if optimizer is not None:
|
||||
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
|
||||
if criterion is not None:
|
||||
criterion = TorchAMPModule(criterion)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Tuple
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
@ -15,7 +15,8 @@ class MixedPrecision(ABC):
|
|||
@abstractmethod
|
||||
def configure(self,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
criterion: Optional[Callable] = None,
|
||||
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
|
||||
# TODO: implement this method
|
||||
pass
|
||||
|
|
|
@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
|
||||
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils import get_current_device
|
||||
|
@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
model: GeminiDDP,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
"""
|
||||
Save sharded model
|
||||
"""
|
||||
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
|
||||
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
total_size = 0
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
for idx, shard_pair in enumerate(state_dict_shard):
|
||||
|
@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
save_state_dict(shard, checkpoint_file_path, use_safetensors)
|
||||
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
|
||||
# only save the index file on the master rank
|
||||
if self.coordinator.is_master():
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(f"The model is going to be split to checkpoint shards. "
|
||||
logging.info(f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
|
||||
|
@ -271,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
|
|||
def configure(
|
||||
self,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: Callable = None,
|
||||
dataloader: DataLoader = None,
|
||||
lr_scheduler: LRScheduler = None,
|
||||
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
criterion: Optional[Callable] = None,
|
||||
dataloader: Optional[DataLoader] = None,
|
||||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
|
||||
if not isinstance(model, ModelWrapper):
|
||||
# convert model to sync bn
|
||||
|
@ -290,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
|
|||
# wrap the model with Gemini
|
||||
model = GeminiModel(model, self.gemini_config, self.verbose)
|
||||
|
||||
if not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
|
||||
if optimizer is not None and \
|
||||
not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = GeminiOptimizer(model.unwrap(),
|
||||
optimizer,
|
||||
self.zero_optim_config,
|
||||
self.optim_kwargs,
|
||||
self.verbose)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
|
|
@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
def configure(
|
||||
self,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: Callable = None,
|
||||
dataloader: DataLoader = None,
|
||||
lr_scheduler: LRScheduler = None,
|
||||
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
criterion: Optional[Callable] = None,
|
||||
dataloader: Optional[DataLoader] = None,
|
||||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
|
||||
if not isinstance(model, ModelWrapper):
|
||||
model = LowLevelZeroModel(model, self.stage, self.precision)
|
||||
|
||||
if not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
|
||||
if optimizer is not None and \
|
||||
not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = LowLevelZeroOptimizer(model.unwrap(),
|
||||
optimizer,
|
||||
self.zero_optim_config,
|
||||
self.optim_kwargs,
|
||||
self.verbose)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Iterator, List, Tuple, Union
|
||||
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
@ -38,11 +38,11 @@ class Plugin(ABC):
|
|||
def configure(
|
||||
self,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: Callable = None,
|
||||
dataloader: DataLoader = None,
|
||||
lr_scheduler: LRScheduler = None,
|
||||
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
criterion: Optional[Callable] = None,
|
||||
dataloader: Optional[DataLoader] = None,
|
||||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
# implement this method
|
||||
pass
|
||||
|
||||
|
|
|
@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
"""
|
||||
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
|
||||
if self.coordinator.is_master():
|
||||
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||
|
||||
|
@ -53,12 +52,27 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
|||
def save_sharded_model(self,
|
||||
model: nn.Module,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
|
||||
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
|
||||
|
||||
def save_sharded_optimizer(self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024):
|
||||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
|
||||
|
||||
|
||||
class TorchDDPModel(ModelWrapper):
|
||||
|
@ -128,11 +142,11 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
def configure(
|
||||
self,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: Callable = None,
|
||||
dataloader: DataLoader = None,
|
||||
lr_scheduler: LRScheduler = None,
|
||||
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
criterion: Optional[Callable] = None,
|
||||
dataloader: Optional[DataLoader] = None,
|
||||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
# cast model to cuda
|
||||
model = model.cuda()
|
||||
|
||||
|
@ -142,7 +156,8 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
# wrap the model with PyTorch DDP
|
||||
model = TorchDDPModel(model, **self.ddp_kwargs)
|
||||
|
||||
if not isinstance(optimizer, OptimizerWrapper):
|
||||
if optimizer is not None and \
|
||||
not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = OptimizerWrapper(optimizer)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import warnings
|
||||
from packaging import version
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
|
||||
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
|
||||
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
|
||||
size_per_shard: int, use_safetensors: bool):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
|
@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
||||
|
||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
|
||||
size_per_shard: int):
|
||||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
|
||||
"""
|
||||
Load optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
|
@ -194,15 +195,16 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
def configure(
|
||||
self,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
criterion: Callable = None,
|
||||
dataloader: DataLoader = None,
|
||||
lr_scheduler: LRScheduler = None,
|
||||
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
criterion: Optional[Callable] = None,
|
||||
dataloader: Optional[DataLoader] = None,
|
||||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
|
||||
# wrap the model with PyTorch FSDP
|
||||
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
|
||||
|
||||
if optimizer is not None:
|
||||
if len(optimizer.param_groups) > 1:
|
||||
warnings.warn(
|
||||
'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
|
||||
|
|
|
@ -103,7 +103,7 @@ class CheckpointIO(ABC):
|
|||
checkpoint: str,
|
||||
shard: bool = False,
|
||||
gather_dtensor: bool = True,
|
||||
variant: str = None,
|
||||
prefix: str = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
"""
|
||||
|
@ -128,7 +128,7 @@ class CheckpointIO(ABC):
|
|||
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
|
||||
that the checkpoint path is a directory path instead of a file path.
|
||||
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
|
||||
variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
|
||||
prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
|
||||
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
|
||||
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
|
||||
"""
|
||||
|
@ -137,17 +137,20 @@ class CheckpointIO(ABC):
|
|||
model = model.unwrap()
|
||||
|
||||
if shard:
|
||||
self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
|
||||
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
|
||||
else:
|
||||
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
|
||||
"""
|
||||
Load optimizer from checkpoint.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): optimizer to be loaded.
|
||||
checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||
"""
|
||||
index_file_exists, index_file_path = has_index_file(checkpoint)
|
||||
|
||||
|
@ -157,7 +160,7 @@ class CheckpointIO(ABC):
|
|||
|
||||
if index_file_exists:
|
||||
# the existence of index file means it is a sharded checkpoint
|
||||
self.load_sharded_optimizer(optimizer, index_file_path)
|
||||
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
|
||||
else:
|
||||
self.load_unsharded_optimizer(optimizer, checkpoint)
|
||||
|
||||
|
@ -218,7 +221,7 @@ class CheckpointIO(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
|
||||
size_per_shard: int, use_safetensors: bool):
|
||||
"""
|
||||
Save model to sharded checkpoint.
|
||||
|
@ -251,7 +254,7 @@ class CheckpointIO(ABC):
|
|||
# ========================================================
|
||||
|
||||
@abstractmethod
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
|
||||
"""
|
||||
Load optimizer from sharded checkpoint.
|
||||
|
||||
|
@ -259,7 +262,6 @@ class CheckpointIO(ABC):
|
|||
optimizer (Optimizer): optimizer to be loaded.
|
||||
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
||||
prefix (str): prefix for the optimizer checkpoint.
|
||||
size_per_shard (int): size per shard in MB.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
@ -8,18 +8,26 @@ from typing import Iterator, Optional, OrderedDict, Tuple
|
|||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
|
||||
from .checkpoint_io_base import CheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
from .utils import (
|
||||
get_base_filenames,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
get_shard_filename,
|
||||
has_index_file,
|
||||
is_safetensors_available,
|
||||
load_param_groups_into_optimizer,
|
||||
load_shard_state_dict,
|
||||
load_state_dict,
|
||||
load_state_dict_into_model,
|
||||
load_states_into_optimizer,
|
||||
save_param_groups,
|
||||
save_state_dict,
|
||||
shard_checkpoint,
|
||||
shard_model_checkpoint,
|
||||
shard_optimizer_checkpoint,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
|
||||
__all__ = ['GeneralCheckpointIO']
|
||||
|
@ -44,12 +52,34 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
# save the checkpoint
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file.
|
||||
"""
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
optimizer.load_state_dict(checkpoint)
|
||||
# If optimizer is wrapped, unwrap it.
|
||||
if isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = optimizer.optim
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
|
||||
Lacking param group file under current directory.')
|
||||
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
|
||||
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
|
||||
for shard_file in checkpoint_files:
|
||||
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
load_states_into_optimizer(optimizer, state_dict, id_map)
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer)
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
|
@ -59,7 +89,54 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
prefix: str,
|
||||
size_per_shard: int,
|
||||
):
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
|
||||
"""
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Offload optimizer states. States are broken into shards within max_shard_size.
|
||||
state_dict = optimizer.state_dict()
|
||||
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
|
||||
|
||||
# Preparing file paths and index file.
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
|
||||
# Store the information of param groups to param_group_file.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(state_dict, group_file_path)
|
||||
|
||||
# Save shards of optimizer states.
|
||||
total_size = 0
|
||||
for idx, shard_pair in enumerate(sharded_state):
|
||||
shard, current_size = shard_pair
|
||||
shard_file = get_shard_filename(states_name, idx)
|
||||
total_size = total_size + current_size
|
||||
for param_id in shard.keys():
|
||||
index_file.append_weight_map(str(param_id), shard_file)
|
||||
|
||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
|
||||
|
||||
# Wrap up index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
checkpoint = load_state_dict(checkpoint)
|
||||
optimizer.load_state_dict(checkpoint)
|
||||
|
||||
def save_unsharded_optimizer(
|
||||
self,
|
||||
|
@ -74,7 +151,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
model: nn.Module,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
"""
|
||||
|
@ -89,9 +166,9 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
|
||||
# shard checkpoint
|
||||
state_dict = model.state_dict()
|
||||
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
|
||||
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
|
||||
|
||||
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
total_size = 0
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
for idx, shard_pair in enumerate(state_dict_shard):
|
||||
|
@ -128,7 +205,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
|
||||
# read checkpoint index file
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
missing_keys = []
|
||||
|
||||
for shard_file in checkpoint_files:
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Union
|
||||
import os
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from .utils import is_dtensor_checkpoint
|
||||
|
||||
|
@ -22,8 +22,10 @@ class CheckpointIndexFile:
|
|||
|
||||
def __init__(self, root_path=None) -> None:
|
||||
self.root_path = root_path
|
||||
self.metadata: dict = dict()
|
||||
self.weight_map: dict = dict()
|
||||
|
||||
# use ordered dict to preserve the tensor checkpoint order
|
||||
self.metadata: Dict = OrderedDict()
|
||||
self.weight_map: Dict = OrderedDict()
|
||||
|
||||
@staticmethod
|
||||
def from_file(index_path: Union[str, Path]):
|
||||
|
@ -109,7 +111,7 @@ class CheckpointIndexFile:
|
|||
return True
|
||||
return False
|
||||
|
||||
def get_checkpoint_fileanames(self) -> List[str]:
|
||||
def get_checkpoint_filenames(self) -> List[str]:
|
||||
"""
|
||||
Get the set of checkpoint filenames in the weight map.
|
||||
|
||||
|
@ -157,6 +159,18 @@ class CheckpointIndexFile:
|
|||
"""
|
||||
return list(self.weight_map.keys())
|
||||
|
||||
def get_param_group_filename(self) -> Union[str, None]:
|
||||
"""
|
||||
Get the file name of param_group file if this is a checkpoint for optimizer.
|
||||
Returns:
|
||||
str: param_group file name
|
||||
"""
|
||||
filename = self.metadata.get("param_groups", None)
|
||||
if filename:
|
||||
return str(self.root_path.joinpath(filename))
|
||||
else:
|
||||
return None
|
||||
|
||||
def write_index_file(self, save_index_file):
|
||||
"""
|
||||
Write index file.
|
||||
|
@ -164,5 +178,5 @@ class CheckpointIndexFile:
|
|||
save_index_file = os.path.join(self.root_path, save_index_file)
|
||||
index = {"metadata": self.metadata, "weight_map": self.weight_map}
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
content = json.dumps(index, indent=2) + "\n"
|
||||
f.write(content)
|
||||
|
|
|
@ -1,17 +1,24 @@
|
|||
# coding=utf-8
|
||||
import re
|
||||
from collections import abc as container_abcs
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
|
||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
STATES_NAME = "pytorch_optim.bin"
|
||||
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
|
||||
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
|
||||
GROUP_FILE_NAME = "pytorch_optim_group.bin"
|
||||
|
||||
# ======================================
|
||||
# General helper functions
|
||||
|
@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
|
|||
# ======================================
|
||||
# Helper functions for saving shard file
|
||||
# ======================================
|
||||
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""
|
||||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
|
@ -110,6 +117,56 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
|
|||
yield current_block, current_block_size
|
||||
|
||||
|
||||
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""
|
||||
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
"""
|
||||
|
||||
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
|
||||
states = state_dict['state']
|
||||
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
|
||||
for param_id, state in states.items():
|
||||
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
|
||||
# A state might contain more than one tensors.
|
||||
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
|
||||
state_size = 0
|
||||
isDTensor = False
|
||||
for state_tensor in state.values():
|
||||
|
||||
# When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
|
||||
# The calculation of tensor size should be skipped to avoid error.
|
||||
if state_tensor is None:
|
||||
continue
|
||||
|
||||
# If the states are stored as DTensors, mark isDTensor as true.
|
||||
if type(state_tensor) == DTensor:
|
||||
isDTensor = True
|
||||
state_size += calculate_tensor_size(state_tensor)
|
||||
|
||||
if not isDTensor:
|
||||
|
||||
if current_block_size + state_size > max_shard_size:
|
||||
ret_block = current_block
|
||||
ret_block_size = current_block_size
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
|
||||
current_block[param_id] = state
|
||||
current_block_size += state_size
|
||||
|
||||
if ret_block != None:
|
||||
yield ret_block, ret_block_size
|
||||
|
||||
yield current_block, current_block_size
|
||||
|
||||
|
||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||
"""
|
||||
load shard state dict into model
|
||||
|
@ -179,6 +236,102 @@ def load_state_dict_into_model(model: nn.Module,
|
|||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
|
||||
"""
|
||||
Load information of param_groups into an initialized optimizer.
|
||||
"""
|
||||
|
||||
# Load list of param_groups from given file path.
|
||||
# The params in saved_groups are in the form of integer indices.
|
||||
saved_groups = torch.load(param_group_path)
|
||||
if not isinstance(saved_groups, List):
|
||||
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
|
||||
|
||||
# The params in param_groups are in the form of pytorch tensors.
|
||||
# For more details, please view source code of Optimizer class in pytorch.
|
||||
param_groups = optimizer.param_groups
|
||||
|
||||
# Check the compatibility of saved_groups and param_groups.
|
||||
if len(param_groups) != len(saved_groups):
|
||||
raise ValueError("loaded state dict has a different number of original parameter groups")
|
||||
param_lens = (len(g['params']) for g in param_groups)
|
||||
saved_lens = (len(g['params']) for g in saved_groups)
|
||||
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||
raise ValueError("loaded state dict contains a parameter group "
|
||||
"that doesn't match the size of optimizer's group")
|
||||
|
||||
# Creating mapping from id to parameters.
|
||||
id_map = {
|
||||
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||
)), chain.from_iterable((g['params'] for g in param_groups)))
|
||||
}
|
||||
|
||||
# Update parameter groups, setting their 'params' value.
|
||||
def update_group(group, new_group):
|
||||
new_group['params'] = group['params']
|
||||
return new_group
|
||||
|
||||
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
|
||||
|
||||
optimizer.__dict__.update({'param_groups': updated_groups})
|
||||
return id_map
|
||||
|
||||
|
||||
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
|
||||
r"""Copies states from `state_dict` into an Optimizer object.
|
||||
|
||||
Args:
|
||||
optimizer(Optimizer): An initialized Optimizer object to be loaded
|
||||
state_dict(dict): a mapping from tensor index (an integer)
|
||||
to its states to be loaded (a mapping from state name to a tensor).
|
||||
id_map(dict): a mapping from tensor index (an integer)
|
||||
to its corresponding parameter (a tensor) whose states will be updated.
|
||||
"""
|
||||
|
||||
def cast(param, value, key=None):
|
||||
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
# that are assumed to always match the type of params.
|
||||
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
|
||||
if (key != "step"):
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: cast(param, v, key=k) for k, v in value.items()}
|
||||
elif isinstance(value, container_abcs.Iterable):
|
||||
return type(value)(cast(param, v) for v in value)
|
||||
else:
|
||||
return value
|
||||
|
||||
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||
# State that is not assigned to params is copied as is (needed for
|
||||
# backward compatibility).
|
||||
new_states = defaultdict(dict)
|
||||
for k, v in state_dict.items():
|
||||
if k in id_map:
|
||||
param = id_map[k]
|
||||
new_states[param] = cast(param, v)
|
||||
else:
|
||||
new_states[k] = v
|
||||
|
||||
optimizer.state.update(new_states)
|
||||
|
||||
|
||||
def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
|
||||
r"""Do the cleaning up work after state_dict has been loaded into optimizer
|
||||
|
||||
Args:
|
||||
optimizer(Optimizer): An optimizer object whose state has just been loaded.
|
||||
"""
|
||||
|
||||
# Do the cleaning up as in src code of Pytorch.
|
||||
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
|
||||
optimizer.defaults.setdefault('differentiable', False)
|
||||
|
||||
|
||||
# ======================================
|
||||
# Helper functions for saving state dict
|
||||
# ======================================
|
||||
|
@ -203,6 +356,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
|
|||
torch.save(state_dict, checkpoint_file_path)
|
||||
|
||||
|
||||
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
|
||||
"""
|
||||
Save information of param_groups to given file path.
|
||||
|
||||
Args:
|
||||
state_dict (dict): state dict.
|
||||
group_file_path (str): path to the group file.
|
||||
"""
|
||||
param_groups = state_dict["param_groups"]
|
||||
torch.save(param_groups, group_file_path)
|
||||
|
||||
|
||||
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
|
||||
"""
|
||||
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
|
||||
|
@ -392,28 +557,44 @@ def load_state_dict(checkpoint_file_path: Path):
|
|||
return torch.load(checkpoint_file_path)
|
||||
|
||||
|
||||
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
if variant is not None and len(variant) > 0:
|
||||
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
|
||||
if prefix is not None and len(prefix) > 0:
|
||||
splits = weights_name.split(".")
|
||||
splits = splits[:-1] + [variant] + splits[-1:]
|
||||
splits = splits[:-1] + [prefix] + splits[-1:]
|
||||
weights_name = ".".join(splits)
|
||||
|
||||
return weights_name
|
||||
|
||||
|
||||
def get_base_filenames(variant: str = None, use_safetensors: bool = False):
|
||||
def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
|
||||
"""
|
||||
generate base weight filenames
|
||||
generate base model weight filenames
|
||||
"""
|
||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||
weights_name = add_variant(weights_name, variant)
|
||||
weights_name = add_prefix(weights_name, prefix)
|
||||
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||
save_index_file = add_variant(save_index_file, variant)
|
||||
save_index_file = add_prefix(save_index_file, prefix)
|
||||
|
||||
return weights_name, save_index_file
|
||||
|
||||
|
||||
def get_optimizer_base_filenames(prefix: str = None):
|
||||
"""
|
||||
generate base optimizer state filenames
|
||||
"""
|
||||
states_name = STATES_NAME
|
||||
states_name = add_prefix(states_name, prefix)
|
||||
|
||||
save_index_file = STATES_INDEX_NAME
|
||||
save_index_file = add_prefix(save_index_file, prefix)
|
||||
|
||||
param_group_file = GROUP_FILE_NAME
|
||||
param_group_file = add_prefix(param_group_file, prefix)
|
||||
|
||||
return states_name, save_index_file, param_group_file
|
||||
|
||||
|
||||
def get_shard_filename(weights_name: str, idx: int):
|
||||
"""
|
||||
get shard file name
|
||||
|
|
|
@ -716,7 +716,10 @@ class _StateDictSharder:
|
|||
tensor_size = calculate_tensor_size(tensor)
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
if self.current_block_size + tensor_size > self.max_shard_size:
|
||||
|
||||
# before we return the current block and create a new block,
|
||||
# we need to ensure that the current block is not empty
|
||||
if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
|
||||
ret_block = self.current_block
|
||||
ret_block_size = self.current_block_size
|
||||
self.current_block = OrderedDict()
|
||||
|
|
|
@ -4,12 +4,15 @@ import pytest
|
|||
import torch
|
||||
|
||||
try:
|
||||
from diffusers import UNet2DModel
|
||||
MODELS = [UNet2DModel]
|
||||
import diffusers
|
||||
MODELS = [diffusers.UNet2DModel]
|
||||
HAS_REPO = True
|
||||
from packaging import version
|
||||
SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
|
||||
except:
|
||||
MODELS = []
|
||||
HAS_REPO = False
|
||||
SKIP_UNET_TEST = False
|
||||
|
||||
from test_autochunk_diffuser_utils import run_test
|
||||
|
||||
|
@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
|
|||
return meta_args, concrete_args
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
SKIP_UNET_TEST,
|
||||
reason="diffusers version > 0.10.2",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
|
|
|
@ -60,7 +60,7 @@ def test_unsharded_checkpoint(use_safetensors: bool):
|
|||
|
||||
|
||||
@pytest.mark.parametrize('use_safetensors', [True, False])
|
||||
def test_sharded_checkpoint(use_safetensors: bool):
|
||||
def test_sharded_model_checkpoint(use_safetensors: bool):
|
||||
# create a model and optimizer
|
||||
model = resnet18()
|
||||
optimizer = Adam(model.parameters(), lr=0.001)
|
||||
|
@ -100,3 +100,111 @@ def test_sharded_checkpoint(use_safetensors: bool):
|
|||
# check for model and optimizer state dict recursively
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||
|
||||
|
||||
def test_sharded_optimizer_checkpoint():
|
||||
|
||||
# create a model and optimizer
|
||||
model = resnet18()
|
||||
optimizer = Adam(model.parameters(), lr=0.001)
|
||||
|
||||
# create test data sample
|
||||
x = torch.randn(1, 3, 224, 224)
|
||||
|
||||
# run fwd and bwd
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# create temp directories for checkpoint
|
||||
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
# save the model and optimizer
|
||||
ckpt_io = GeneralCheckpointIO()
|
||||
|
||||
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
|
||||
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
|
||||
|
||||
# create new model
|
||||
new_model = resnet18()
|
||||
new_optimizer = Adam(new_model.parameters(), lr=0.001)
|
||||
|
||||
ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
|
||||
ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
|
||||
|
||||
# check for model and optimizer state dict recursively
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||
|
||||
# continue running fwd and bwd
|
||||
for _ in range(5):
|
||||
y = new_model(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
new_optimizer.step()
|
||||
|
||||
# save the newly got optimizer
|
||||
ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
|
||||
ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
|
||||
|
||||
# create another new model
|
||||
new_new_model = resnet18()
|
||||
new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001)
|
||||
|
||||
ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True)
|
||||
ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name))
|
||||
|
||||
# check for model and optimizer state dict recursively
|
||||
check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict())
|
||||
check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())
|
||||
|
||||
|
||||
def test_sharded_optimizer_multiple_param_groups():
|
||||
|
||||
# create a model and optimizer
|
||||
model = resnet18()
|
||||
optimizer = Adam([{
|
||||
'params': model.layer1.parameters()
|
||||
}, {
|
||||
'params': model.layer2.parameters(),
|
||||
'lr': 0.002
|
||||
}],
|
||||
lr=0.001)
|
||||
|
||||
# create test data sample
|
||||
x = torch.randn(1, 3, 224, 224)
|
||||
|
||||
# run fwd and bwd
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# create temp directories for checkpoint
|
||||
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
# save the model and optimizer
|
||||
ckpt_io = GeneralCheckpointIO()
|
||||
|
||||
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
|
||||
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
|
||||
|
||||
# create new model
|
||||
new_model = resnet18()
|
||||
new_optimizer = Adam([{
|
||||
'params': new_model.layer1.parameters()
|
||||
}, {
|
||||
'params': new_model.layer2.parameters(),
|
||||
'lr': 0.002
|
||||
}],
|
||||
lr=0.001)
|
||||
|
||||
ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
|
||||
ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
|
||||
|
||||
# check for model and optimizer state dict recursively
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict())
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
|
||||
|
|
|
@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
|
|||
|
||||
|
||||
@parameterize('shard', [True, False])
|
||||
def check_torch_ddp_checkpointIO(shard: bool):
|
||||
@parameterize('size_per_shard', [16, 128])
|
||||
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
|
||||
plugin = TorchDDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
model = resnet18()
|
||||
|
@ -38,10 +39,8 @@ def check_torch_ddp_checkpointIO(shard: bool):
|
|||
model_ckpt_path = f"{tempdir}/model"
|
||||
optimizer_ckpt_path = f"{tempdir}/optimizer"
|
||||
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
|
||||
booster.save_model(model, model_ckpt_path, shard=shard)
|
||||
if not shard:
|
||||
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
|
||||
booster.save_optimizer(optimizer, optimizer_ckpt_path)
|
||||
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
|
||||
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
|
||||
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
|
||||
dist.barrier()
|
||||
|
||||
|
@ -55,7 +54,6 @@ def check_torch_ddp_checkpointIO(shard: bool):
|
|||
booster.load_model(new_model, model_ckpt_path)
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
|
||||
|
||||
if not shard:
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
|
||||
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
|
||||
|
|
Loading…
Reference in New Issue