From 5452df63c58385985fcd89749f266109eb9ea8b8 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 18 May 2023 20:05:59 +0800 Subject: [PATCH] [plugin] torch ddp plugin supports sharded model checkpoint (#3775) * [plugin] torch ddp plugin add save sharded model * [test] fix torch ddp ckpt io test * [test] fix torch ddp ckpt io test * [test] fix low level zero plugin test * [test] fix low level zero plugin test * [test] add debug info * [test] add debug info * [test] add debug info * [test] add debug info * [test] add debug info * [test] fix low level zero plugin test * [test] fix low level zero plugin test * [test] remove debug info --- colossalai/booster/plugin/torch_ddp_plugin.py | 12 +++- .../checkpoint_io/checkpoint_io_base.py | 6 +- colossalai/checkpoint_io/utils.py | 62 +++++++++++-------- .../test_plugin/test_low_level_zero_plugin.py | 7 ++- .../test_torch_ddp_checkpoint_io.py | 50 ++++++++++----- 5 files changed, 86 insertions(+), 51 deletions(-) diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 99cd2f779..b317ccf48 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator, List, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple, Union import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP @@ -50,6 +50,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + def save_sharded_model(self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = False, + variant: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False): + if self.coordinator.is_master(): + super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors) + class TorchDDPModel(ModelWrapper): diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 9cf344ecc..fbc8fc542 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Union -from typing import Optional +from typing import Optional, Union import torch import torch.nn as nn @@ -84,9 +83,8 @@ class CheckpointIO(ABC): # containing no distributed tensors, dtensor -> full tensor conversion # should be done offline via our CLI # the existence of index file means it is a sharded checkpoint - ckpt_path = Path(checkpoint) index_file_exists, index_file_path = has_index_file(checkpoint) - + # return the origin model instead of the unwrapped model origin_model = model diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index ee4bd72e8..435feda4a 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,10 +1,12 @@ # coding=utf-8 +import re from pathlib import Path +from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple + import torch import torch.nn as nn -from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator + from colossalai.tensor.d_tensor.d_tensor import DTensor -import re SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -15,6 +17,7 @@ WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" # General helper functions # ====================================== + def calculate_tensor_size(tensor: torch.Tensor) -> float: """ Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. @@ -28,6 +31,7 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float: """ return tensor.numel() * tensor.element_size() / 1024 / 1024 + def is_safetensors_available() -> bool: """ Check whether safetensors is available. @@ -78,7 +82,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # Helper functions for saving shard file # ====================================== def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: - """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. @@ -100,35 +103,39 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It current_block_size = 0 current_block[key] = weight current_block_size += weight_size - + if ret_block != None: yield ret_block, ret_block_size yield current_block, current_block_size -def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): +def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): """ load shard state dict into model """ if use_safetensors and not checkpoint_file.suffix == ".safetensors": raise Exception("load the model using `safetensors`, but no file endwith .safetensors") if use_safetensors: - from safetensors.torch import safe_open from safetensors.torch import load_file as safe_load_file + from safetensors.torch import safe_open with safe_open(checkpoint_file, framework="pt") as f: metadata = f.metadata() if metadata["format"] != "pt": raise NotImplementedError( - f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." - ) + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.") return safe_load_file(checkpoint_file) else: return torch.load(checkpoint_file) - -def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True): + + +def load_state_dict_into_model(model: nn.Module, + state_dict: torch.Tensor, + missing_keys: List, + strict: bool = False, + load_sub_module: bool = True): r"""Copies parameters and buffers from :attr:`state_dict` into - this module and its descendants. + this module and its descendants. Args: state_dict (dict): a dict containing parameters and @@ -166,11 +173,12 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi if strict: if len(unexpected_keys) > 0: - error_msgs = 'Unexpected key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in unexpected_keys)) + error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join( + '"{}"'.format(k) for k in unexpected_keys)) raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) - + model.__class__.__name__, "\n\t".join(error_msgs))) + + # ====================================== # Helper functions for saving state dict # ====================================== @@ -350,6 +358,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: return True, index_files[0] else: return False, None + else: + raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.') def load_state_dict(checkpoint_file_path: Path): @@ -380,7 +390,6 @@ def load_state_dict(checkpoint_file_path: Path): else: # load with torch return torch.load(checkpoint_file_path) - def add_variant(weights_name: str, variant: Optional[str] = None) -> str: @@ -392,17 +401,18 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name -def get_base_filenames(variant: str=None, use_safetensors: bool=False): - """ - generate base weight filenames - """ - weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME - weights_name = add_variant(weights_name, variant) +def get_base_filenames(variant: str = None, use_safetensors: bool = False): + """ + generate base weight filenames + """ + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_variant(weights_name, variant) - save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - save_index_file = add_variant(save_index_file, variant) + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + save_index_file = add_variant(save_index_file, variant) + + return weights_name, save_index_file - return weights_name, save_index_file def get_shard_filename(weights_name: str, idx: int): """ @@ -410,4 +420,4 @@ def get_shard_filename(weights_name: str, idx: int): """ shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") - return shard_file \ No newline at end of file + return shard_file diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index d84b96f77..f70f27be2 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -11,9 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo # These models are not compatible with AMP -_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`'] +_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn'] # These models have no parameters -_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] +_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch'] # These models will get stuck _STUCK_MODELS = [ 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', @@ -67,6 +67,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): skipped_models.append(name) continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() if err is None: @@ -91,7 +92,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() def test_low_level_zero_plugin(early_stop: bool = True): - spawn(run_dist, 2, early_stop=early_stop) + spawn(run_dist, 4, early_stop=early_stop) if __name__ == '__main__': diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index 3c05ea9f1..8a4217941 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -1,6 +1,7 @@ import tempfile import torch +import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import SGD from torchvision.models import resnet18 @@ -8,12 +9,12 @@ from torchvision.models import resnet18 import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import TorchDDPPlugin -from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPCheckpointIO from colossalai.interface import OptimizerWrapper -from colossalai.testing import check_state_dict_equal, rerun_if_address_is_in_use, spawn +from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn -def check_torch_ddp_checkpointIO(): +@parameterize('shard', [True, False]) +def check_torch_ddp_checkpointIO(shard: bool): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) model = resnet18() @@ -34,23 +35,38 @@ def check_torch_ddp_checkpointIO(): optimizer.step() scheduler.step() - optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() - lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile() - ckpt_io = TorchDDPCheckpointIO() - ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name) - ckpt_io.save_lr_scheduler(scheduler, lr_scheduler_ckpt_tempfile.name) + with tempfile.TemporaryDirectory() as tempdir: + obj = [tempdir] + dist.broadcast_object_list(obj, src=0) + tempdir = obj[0] # use the same directory on all ranks - new_model = resnet18() - new_optimizer = SGD((new_model.parameters()), lr=0.001) - new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1) - _, new_optimizer, _, _, new_scheduler = booster.boost(new_model, new_optimizer, lr_scheduler=new_scheduler) + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler" + booster.save_model(model, model_ckpt_path, shard=shard) + if not shard: + # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint + booster.save_optimizer(optimizer, optimizer_ckpt_path) + booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path) + dist.barrier() - if ckpt_io.coordinator.is_master(): - ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + new_model = resnet18() + new_optimizer = SGD((new_model.parameters()), lr=0.001) + new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1) + new_model, new_optimizer, _, _, new_scheduler = booster.boost(new_model, + new_optimizer, + lr_scheduler=new_scheduler) - ckpt_io.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_tempfile.name) - check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) + + if not shard: + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path) + check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) + + dist.barrier() def run_dist(rank, world_size, port):