Merge pull request #4025 from hpcaitech/develop

[sync] sync develop to main
pull/4033/head
Frank Lee 2023-06-19 10:31:34 +08:00 committed by GitHub
commit ca768eb62d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 615 additions and 134 deletions

View File

@ -40,8 +40,8 @@ jobs:
- name: Copy testmon cache
run: | # branch name may contain slash, we need to replace it with space
export REF_BRANCH=$(echo ${{ github.event.ref }} | sed "s/\// /")
if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ]; then
[ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ] && cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}"
if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ] && [ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ]; then
cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}"
fi
env:
MAIN_BRANCH: ${{ github.event.master_branch }}
@ -60,12 +60,15 @@ jobs:
defaults:
run:
shell: bash
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- name: Copy testmon cache
run: | # branch name may contain slash, we need to replace it with space
export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
if [ -d "/github/home/testmon_cache/${BASE}" ]; then
[ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ] && mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER}
if [ -d "/github/home/testmon_cache/${BASE}" ] and [ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ]; then
mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER}
fi
env:
PR_NUMBER: ${{ github.event.number }}
@ -83,6 +86,9 @@ jobs:
changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }}
anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }}
runs-on: ubuntu-latest
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v2
with:
@ -140,6 +146,9 @@ jobs:
defaults:
run:
shell: bash
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- name: Checkout TensorNVMe
uses: actions/checkout@v2
@ -150,7 +159,9 @@ jobs:
- name: Restore TensorNVMe Cache
run: |
[ ! -z "$(ls -A /github/home/tensornvme_cache/)" ] && cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
if [ -d /github/home/tensornvme_cache ] && [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ]; then
cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
fi
- name: Install TensorNVMe
run: |
@ -173,7 +184,9 @@ jobs:
if: needs.detect.outputs.anyExtensionFileChanged != 'true'
run: |
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
if [ -d /github/home/cuda_ext_cache ] && [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ]; then
cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
fi
- name: Install Colossal-AI
run: |
@ -264,8 +277,8 @@ jobs:
if: github.event.pull_request.merged == true
run: | # branch name may contain slash, we need to replace it with space
export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ]; then
[ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ] && cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/"
if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ] && [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ]; then
cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/"
fi
env:
PR_NUMBER: ${{ github.event.pull_request.number }}

View File

@ -12,6 +12,9 @@ jobs:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v3
- id: set-matrix
@ -40,6 +43,9 @@ jobs:
image: ${{ matrix.container }}
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
timeout-minutes: 120
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- name: Install dependencies
run: |

View File

@ -16,6 +16,9 @@ jobs:
github.event.pull_request.draft == false &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v2
@ -31,6 +34,9 @@ jobs:
github.event.pull_request.draft == false &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v2
with:

View File

@ -19,6 +19,9 @@ jobs:
outputs:
any_changed: ${{ steps.changed-files.outputs.any_changed }}
changed_files: ${{ steps.changed-files.outputs.all_changed_files }}
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
name: Detect changed example files
steps:
- uses: actions/checkout@v3
@ -59,6 +62,9 @@ jobs:
defaults:
run:
shell: bash
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- name: Checkout ColossalAI-Documentation
uses: actions/checkout@v2

View File

@ -20,6 +20,9 @@ jobs:
matrix: ${{ steps.setup-matrix.outputs.matrix }}
anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
name: Detect changed example files
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v3
with:
@ -77,6 +80,9 @@ jobs:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v3

View File

@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@ -97,10 +98,10 @@ class Booster:
def boost(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
@ -165,11 +166,11 @@ class Booster:
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
"""Load model from checkpoint.
Args:
model (nn.Module): A model boosted by Booster.
model (nn.Module or ModelWrapper): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
strict (bool, optional): whether to strictly enforce that the keys
@ -179,24 +180,34 @@ class Booster:
self.checkpoint_io.load_model(model, checkpoint, strict)
def save_model(self,
model: nn.Module,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
prefix: str = None,
shard: bool = False,
size_per_shard: int = 1024):
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
"""Save model to checkpoint.
Args:
model (nn.Module): A model boosted by Booster.
model (nn.Module or ModelWrapper): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
shard (bool, optional): Whether to save checkpoint a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
"""
self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
self.checkpoint_io.save_model(model,
checkpoint=checkpoint,
shard=shard,
gather_dtensor=gather_dtensor,
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""Load optimizer from checkpoint.
@ -205,12 +216,21 @@ class Booster:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
"""Save optimizer to checkpoint.
Warning: Saving sharded optimizer checkpoint is not supported yet.
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
"""
Save optimizer to checkpoint.
Args:
optimizer (Optimizer): An optimizer boosted by Booster.
@ -218,9 +238,12 @@ class Booster:
It is a file path if ``shard=False``. Otherwise, it is a directory path.
shard (bool, optional): Whether to save checkpoint a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""Save lr scheduler to checkpoint.

View File

@ -115,9 +115,11 @@ class FP16TorchMixedPrecision(MixedPrecision):
def configure(self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
model = TorchAMPModule(model)
if optimizer is not None:
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
if criterion is not None:
criterion = TorchAMPModule(criterion)

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Tuple
from typing import Callable, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
@ -15,7 +15,8 @@ class MixedPrecision(ABC):
@abstractmethod
def configure(self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
# TODO: implement this method
pass

View File

@ -12,7 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@ -76,14 +76,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
model: GeminiDDP,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save sharded model
"""
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
@ -99,8 +99,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
save_state_dict(shard, checkpoint_file_path, use_safetensors)
index_file.append_meta_data("total_size", total_size)
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
logging.info(f"The model is going to be split to checkpoint shards. "
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
@ -271,11 +274,11 @@ class GeminiPlugin(DPPluginBase):
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
# convert model to sync bn
@ -290,8 +293,12 @@ class GeminiPlugin(DPPluginBase):
# wrap the model with Gemini
model = GeminiModel(model, self.gemini_config, self.verbose)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(),
optimizer,
self.zero_optim_config,
self.optim_kwargs,
self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler

View File

@ -197,17 +197,21 @@ class LowLevelZeroPlugin(DPPluginBase):
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.stage, self.precision)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = LowLevelZeroOptimizer(model.unwrap(),
optimizer,
self.zero_optim_config,
self.optim_kwargs,
self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch.nn as nn
from torch.optim import Optimizer
@ -38,11 +38,11 @@ class Plugin(ABC):
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# implement this method
pass

View File

@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
@ -53,12 +52,27 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
def save_sharded_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
"""
Save optimizer to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
class TorchDDPModel(ModelWrapper):
@ -128,11 +142,11 @@ class TorchDDPPlugin(DPPluginBase):
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# cast model to cuda
model = model.cuda()
@ -142,7 +156,8 @@ class TorchDDPPlugin(DPPluginBase):
# wrap the model with PyTorch DDP
model = TorchDDPModel(model, **self.ddp_kwargs)
if not isinstance(optimizer, OptimizerWrapper):
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler

View File

@ -1,9 +1,9 @@
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import warnings
from packaging import version
from torch.distributed import ProcessGroup
@ -69,7 +69,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
@ -87,13 +87,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
size_per_shard: int):
"""
Save optimizer to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
"""
Load optimizer to checkpoint but only on master process.
"""
@ -194,15 +195,16 @@ class TorchFSDPPlugin(DPPluginBase):
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
if optimizer is not None:
if len(optimizer.param_groups) > 1:
warnings.warn(
'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'

View File

@ -103,7 +103,7 @@ class CheckpointIO(ABC):
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
variant: str = None,
prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
"""
@ -128,7 +128,7 @@ class CheckpointIO(ABC):
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
"""
@ -137,17 +137,20 @@ class CheckpointIO(ABC):
model = model.unwrap()
if shard:
self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
"""
Load optimizer from checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
index_file_exists, index_file_path = has_index_file(checkpoint)
@ -157,7 +160,7 @@ class CheckpointIO(ABC):
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
self.load_sharded_optimizer(optimizer, index_file_path)
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
@ -218,7 +221,7 @@ class CheckpointIO(ABC):
pass
@abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
"""
Save model to sharded checkpoint.
@ -251,7 +254,7 @@ class CheckpointIO(ABC):
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
Load optimizer from sharded checkpoint.
@ -259,7 +262,6 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass

View File

@ -8,18 +8,26 @@ from typing import Iterator, Optional, OrderedDict, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
get_base_filenames,
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
has_index_file,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
shard_checkpoint,
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
)
__all__ = ['GeneralCheckpointIO']
@ -44,12 +52,34 @@ class GeneralCheckpointIO(CheckpointIO):
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
Load sharded optimizer with the given path to index file.
"""
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.optim
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory.')
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
del state_dict
gc.collect()
sharded_optimizer_loading_epilogue(optimizer)
def save_sharded_optimizer(
self,
@ -59,7 +89,54 @@ class GeneralCheckpointIO(CheckpointIO):
prefix: str,
size_per_shard: int,
):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Offload optimizer states. States are broken into shards within max_shard_size.
state_dict = optimizer.state_dict()
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(state_dict, group_file_path)
# Save shards of optimizer states.
total_size = 0
for idx, shard_pair in enumerate(sharded_state):
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for param_id in shard.keys():
index_file.append_weight_map(str(param_id), shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
def save_unsharded_optimizer(
self,
@ -74,7 +151,7 @@ class GeneralCheckpointIO(CheckpointIO):
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
@ -89,9 +166,9 @@ class GeneralCheckpointIO(CheckpointIO):
# shard checkpoint
state_dict = model.state_dict()
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
@ -128,7 +205,7 @@ class GeneralCheckpointIO(CheckpointIO):
# read checkpoint index file
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
missing_keys = []
for shard_file in checkpoint_files:

View File

@ -1,8 +1,8 @@
import json
from pathlib import Path
from typing import Any, List, Union
import os
import json
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Union
from .utils import is_dtensor_checkpoint
@ -22,8 +22,10 @@ class CheckpointIndexFile:
def __init__(self, root_path=None) -> None:
self.root_path = root_path
self.metadata: dict = dict()
self.weight_map: dict = dict()
# use ordered dict to preserve the tensor checkpoint order
self.metadata: Dict = OrderedDict()
self.weight_map: Dict = OrderedDict()
@staticmethod
def from_file(index_path: Union[str, Path]):
@ -109,7 +111,7 @@ class CheckpointIndexFile:
return True
return False
def get_checkpoint_fileanames(self) -> List[str]:
def get_checkpoint_filenames(self) -> List[str]:
"""
Get the set of checkpoint filenames in the weight map.
@ -157,6 +159,18 @@ class CheckpointIndexFile:
"""
return list(self.weight_map.keys())
def get_param_group_filename(self) -> Union[str, None]:
"""
Get the file name of param_group file if this is a checkpoint for optimizer.
Returns:
str: param_group file name
"""
filename = self.metadata.get("param_groups", None)
if filename:
return str(self.root_path.joinpath(filename))
else:
return None
def write_index_file(self, save_index_file):
"""
Write index file.
@ -164,5 +178,5 @@ class CheckpointIndexFile:
save_index_file = os.path.join(self.root_path, save_index_file)
index = {"metadata": self.metadata, "weight_map": self.weight_map}
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
content = json.dumps(index, indent=2) + "\n"
f.write(content)

View File

@ -1,17 +1,24 @@
# coding=utf-8
import re
from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.tensor.d_tensor.d_tensor import DTensor
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
STATES_NAME = "pytorch_optim.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
GROUP_FILE_NAME = "pytorch_optim_group.bin"
# ======================================
# General helper functions
@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# ======================================
# Helper functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@ -110,6 +117,56 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
yield current_block, current_block_size
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
states = state_dict['state']
current_block = {}
current_block_size = 0
for param_id, state in states.items():
ret_block = None
ret_block_size = 0
# A state might contain more than one tensors.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
# The calculation of tensor size should be skipped to avoid error.
if state_tensor is None:
continue
# If the states are stored as DTensors, mark isDTensor as true.
if type(state_tensor) == DTensor:
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
if not isDTensor:
if current_block_size + state_size > max_shard_size:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[param_id] = state
current_block_size += state_size
if ret_block != None:
yield ret_block, ret_block_size
yield current_block, current_block_size
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
"""
load shard state dict into model
@ -179,6 +236,102 @@ def load_state_dict_into_model(model: nn.Module,
model.__class__.__name__, "\n\t".join(error_msgs)))
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
"""
Load information of param_groups into an initialized optimizer.
"""
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
saved_groups = torch.load(param_group_path)
if not isinstance(saved_groups, List):
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
# The params in param_groups are in the form of pytorch tensors.
# For more details, please view source code of Optimizer class in pytorch.
param_groups = optimizer.param_groups
# Check the compatibility of saved_groups and param_groups.
if len(param_groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of original parameter groups")
param_lens = (len(g['params']) for g in param_groups)
saved_lens = (len(g['params']) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")
# Creating mapping from id to parameters.
id_map = {
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
)), chain.from_iterable((g['params'] for g in param_groups)))
}
# Update parameter groups, setting their 'params' value.
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
optimizer.__dict__.update({'param_groups': updated_groups})
return id_map
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
r"""Copies states from `state_dict` into an Optimizer object.
Args:
optimizer(Optimizer): An initialized Optimizer object to be loaded
state_dict(dict): a mapping from tensor index (an integer)
to its states to be loaded (a mapping from state name to a tensor).
id_map(dict): a mapping from tensor index (an integer)
to its corresponding parameter (a tensor) whose states will be updated.
"""
def cast(param, value, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
if (key != "step"):
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
return value
elif isinstance(value, dict):
return {k: cast(param, v, key=k) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
new_states = defaultdict(dict)
for k, v in state_dict.items():
if k in id_map:
param = id_map[k]
new_states[param] = cast(param, v)
else:
new_states[k] = v
optimizer.state.update(new_states)
def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
r"""Do the cleaning up work after state_dict has been loaded into optimizer
Args:
optimizer(Optimizer): An optimizer object whose state has just been loaded.
"""
# Do the cleaning up as in src code of Pytorch.
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
optimizer.defaults.setdefault('differentiable', False)
# ======================================
# Helper functions for saving state dict
# ======================================
@ -203,6 +356,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
torch.save(state_dict, checkpoint_file_path)
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
"""
Save information of param_groups to given file path.
Args:
state_dict (dict): state dict.
group_file_path (str): path to the group file.
"""
param_groups = state_dict["param_groups"]
torch.save(param_groups, group_file_path)
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
@ -392,28 +557,44 @@ def load_state_dict(checkpoint_file_path: Path):
return torch.load(checkpoint_file_path)
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None and len(variant) > 0:
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
if prefix is not None and len(prefix) > 0:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
splits = splits[:-1] + [prefix] + splits[-1:]
weights_name = ".".join(splits)
return weights_name
def get_base_filenames(variant: str = None, use_safetensors: bool = False):
def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
"""
generate base weight filenames
generate base model weight filenames
"""
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)
weights_name = add_prefix(weights_name, prefix)
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = add_variant(save_index_file, variant)
save_index_file = add_prefix(save_index_file, prefix)
return weights_name, save_index_file
def get_optimizer_base_filenames(prefix: str = None):
"""
generate base optimizer state filenames
"""
states_name = STATES_NAME
states_name = add_prefix(states_name, prefix)
save_index_file = STATES_INDEX_NAME
save_index_file = add_prefix(save_index_file, prefix)
param_group_file = GROUP_FILE_NAME
param_group_file = add_prefix(param_group_file, prefix)
return states_name, save_index_file, param_group_file
def get_shard_filename(weights_name: str, idx: int):
"""
get shard file name

View File

@ -716,7 +716,10 @@ class _StateDictSharder:
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
if self.current_block_size + tensor_size > self.max_shard_size:
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()

View File

@ -4,12 +4,15 @@ import pytest
import torch
try:
from diffusers import UNet2DModel
MODELS = [UNet2DModel]
import diffusers
MODELS = [diffusers.UNet2DModel]
HAS_REPO = True
from packaging import version
SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
except:
MODELS = []
HAS_REPO = False
SKIP_UNET_TEST = False
from test_autochunk_diffuser_utils import run_test
@ -32,6 +35,10 @@ def get_data(shape: tuple) -> Tuple[List, List]:
return meta_args, concrete_args
@pytest.mark.skipif(
SKIP_UNET_TEST,
reason="diffusers version > 0.10.2",
)
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",

View File

@ -60,7 +60,7 @@ def test_unsharded_checkpoint(use_safetensors: bool):
@pytest.mark.parametrize('use_safetensors', [True, False])
def test_sharded_checkpoint(use_safetensors: bool):
def test_sharded_model_checkpoint(use_safetensors: bool):
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)
@ -100,3 +100,111 @@ def test_sharded_checkpoint(use_safetensors: bool):
# check for model and optimizer state dict recursively
check_state_dict_equal(model.state_dict(), new_model.state_dict())
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
def test_sharded_optimizer_checkpoint():
# create a model and optimizer
model = resnet18()
optimizer = Adam(model.parameters(), lr=0.001)
# create test data sample
x = torch.randn(1, 3, 224, 224)
# run fwd and bwd
y = model(x)
loss = y.sum()
loss.backward()
optimizer.step()
# create temp directories for checkpoint
model_ckpt_dir = tempfile.TemporaryDirectory()
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
# save the model and optimizer
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
# create new model
new_model = resnet18()
new_optimizer = Adam(new_model.parameters(), lr=0.001)
ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
# check for model and optimizer state dict recursively
check_state_dict_equal(model.state_dict(), new_model.state_dict())
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
# continue running fwd and bwd
for _ in range(5):
y = new_model(x)
loss = y.sum()
loss.backward()
new_optimizer.step()
# save the newly got optimizer
ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
# create another new model
new_new_model = resnet18()
new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001)
ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True)
ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name))
# check for model and optimizer state dict recursively
check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict())
check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())
def test_sharded_optimizer_multiple_param_groups():
# create a model and optimizer
model = resnet18()
optimizer = Adam([{
'params': model.layer1.parameters()
}, {
'params': model.layer2.parameters(),
'lr': 0.002
}],
lr=0.001)
# create test data sample
x = torch.randn(1, 3, 224, 224)
# run fwd and bwd
y = model(x)
loss = y.sum()
loss.backward()
optimizer.step()
# create temp directories for checkpoint
model_ckpt_dir = tempfile.TemporaryDirectory()
optimizer_ckpt_dir = tempfile.TemporaryDirectory()
# save the model and optimizer
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
# create new model
new_model = resnet18()
new_optimizer = Adam([{
'params': new_model.layer1.parameters()
}, {
'params': new_model.layer2.parameters(),
'lr': 0.002
}],
lr=0.001)
ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name))
# check for model and optimizer state dict recursively
check_state_dict_equal(model.state_dict(), new_model.state_dict())
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())

View File

@ -13,7 +13,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
@parameterize('shard', [True, False])
def check_torch_ddp_checkpointIO(shard: bool):
@parameterize('size_per_shard', [16, 128])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
@ -38,10 +39,8 @@ def check_torch_ddp_checkpointIO(shard: bool):
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
booster.save_model(model, model_ckpt_path, shard=shard)
if not shard:
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
booster.save_optimizer(optimizer, optimizer_ckpt_path)
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
dist.barrier()
@ -55,7 +54,6 @@ def check_torch_ddp_checkpointIO(shard: bool):
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
if not shard:
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)