[plugin] torch ddp plugin supports sharded model checkpoint (#3775)

* [plugin] torch ddp plugin add save sharded model

* [test] fix torch ddp ckpt io test

* [test] fix torch ddp ckpt io test

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] add debug info

* [test] fix low level zero plugin test

* [test] fix low level zero plugin test

* [test] remove debug info
Hongxin Liu 2023-05-18 20:05:59 +08:00 committed by GitHub
parent 2703a37ac9
commit 5452df63c5
5 changed files with 86 additions and 51 deletions

View File

@@ -1,4 +1,4 @@
from typing import Callable, Iterator, List, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -50,6 +50,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
class TorchDDPModel(ModelWrapper):

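The override above writes a sharded checkpoint only from the master rank and defers the actual sharding to `GeneralCheckpointIO.save_sharded_model`. Below is a minimal sketch of how this path is reached through the Booster API, mirroring the test added later in this commit; the checkpoint directory is a placeholder and the distributed launch is assumed to have happened already.

```python
# Minimal sketch, assuming colossalai.launch has already initialized the
# process group (omitted here); "./ckpt_dir" is a placeholder path.
import torch
from torchvision.models import resnet18

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=True routes through TorchDDPCheckpointIO.save_sharded_model, so only
# the master rank writes the shard files and the index.
booster.save_model(model, "./ckpt_dir", shard=True)
```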
View File

@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union
from typing import Optional
from typing import Optional, Union
import torch
import torch.nn as nn
@@ -84,9 +83,8 @@ class CheckpointIO(ABC):
# containing no distributed tensors, dtensor -> full tensor conversion
# should be done offline via our CLI
# the existence of index file means it is a sharded checkpoint
ckpt_path = Path(checkpoint)
index_file_exists, index_file_path = has_index_file(checkpoint)
# return the origin model instead of the unwrapped model
origin_model = model

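With `ckpt_path = Path(checkpoint)` dropped, the base `CheckpointIO` now decides between a sharded and an unsharded load purely from `has_index_file`. A rough sketch of that decision follows; the module path in the import is an assumption, and the helper below is illustrative rather than the method actually used in `load_model`.

```python
from colossalai.checkpoint_io.utils import has_index_file  # module path assumed

def describe_checkpoint(checkpoint: str) -> str:
    # The existence of an index file (e.g. pytorch_model.bin.index.json) is what
    # marks a checkpoint as sharded; otherwise it is treated as a single file.
    index_file_exists, index_file_path = has_index_file(checkpoint)
    if index_file_exists:
        return f"sharded checkpoint, index at {index_file_path}"
    return "unsharded checkpoint"
```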
View File

@@ -1,10 +1,12 @@
# coding=utf-8
import re
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
import torch
import torch.nn as nn
from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
from colossalai.tensor.d_tensor.d_tensor import DTensor
import re
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@@ -15,6 +17,7 @@ WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
# General helper functions
# ======================================
def calculate_tensor_size(tensor: torch.Tensor) -> float:
"""
Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -28,6 +31,7 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
"""
return tensor.numel() * tensor.element_size() / 1024 / 1024
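As a quick sanity check of the formula above (numel × element_size / 1024 / 1024), a 1000 × 1000 fp32 tensor comes out at roughly 3.81 MB:

```python
import torch

# Mirrors calculate_tensor_size: 1_000_000 elements * 4 bytes = 4_000_000 bytes,
# and 4_000_000 / 1024 / 1024 ≈ 3.81 MB.
t = torch.zeros(1000, 1000, dtype=torch.float32)
print(t.numel() * t.element_size() / 1024 / 1024)  # ≈ 3.8147
```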
def is_safetensors_available() -> bool:
"""
Check whether safetensors is available.
@@ -78,7 +82,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# Helper functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@@ -100,35 +103,39 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
current_block_size = 0
current_block[key] = weight
current_block_size += weight_size
if ret_block != None:
yield ret_block, ret_block_size
yield current_block, current_block_size
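`shard_checkpoint` is a generator: it walks the state dict, accumulates weights into blocks of at most `max_shard_size` MB, and yields each `(block, block_size)` pair. The save routine that consumes it is not part of this diff, so the following is only an illustrative consumer; the import path and output directory are assumptions.

```python
import os

import torch
from torchvision.models import resnet18

# Module path assumed; shard_checkpoint, get_base_filenames and
# get_shard_filename are the helpers defined in this file.
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, shard_checkpoint

ckpt_dir = "./ckpt_dir"  # placeholder output directory
os.makedirs(ckpt_dir, exist_ok=True)

state_dict = resnet18().state_dict()
weights_name, save_index_file = get_base_filenames(use_safetensors=False)

# Keep each shard under ~10 MB; files come out as pytorch_model-00001.bin, ...
for idx, (shard, shard_size_mb) in enumerate(shard_checkpoint(state_dict, max_shard_size=10)):
    torch.save(shard, os.path.join(ckpt_dir, get_shard_filename(weights_name, idx)))
```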
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
"""
load shard state dict into model
"""
if use_safetensors and not checkpoint_file.suffix == ".safetensors":
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
if use_safetensors:
from safetensors.torch import safe_open
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file)
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
def load_state_dict_into_model(model: nn.Module,
state_dict: torch.Tensor,
missing_keys: List,
strict: bool = False,
load_sub_module: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
this module and its descendants.
Args:
state_dict (dict): a dict containing parameters and
@@ -166,11 +173,12 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
if strict:
if len(unexpected_keys) > 0:
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys))
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in unexpected_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
model.__class__.__name__, "\n\t".join(error_msgs)))
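Taken together, `load_shard_state_dict` and `load_state_dict_into_model` are what a sharded load builds on: the first reads one shard file (optionally through safetensors), the second copies it into the module. A hedged sketch of the pairing follows; the module path and the shard file name are assumptions.

```python
# Hedged sketch of pairing the two loading helpers; the import path and the
# shard file name are assumptions for illustration only.
from pathlib import Path

from torchvision.models import resnet18

from colossalai.checkpoint_io.utils import load_shard_state_dict, load_state_dict_into_model

model = resnet18()
missing_keys: list = []

# Read one shard (use_safetensors=True is only valid for *.safetensors files,
# otherwise load_shard_state_dict raises) and copy it into the module.
shard = load_shard_state_dict(Path("./ckpt_dir/pytorch_model-00001.bin"), use_safetensors=False)
load_state_dict_into_model(model, shard, missing_keys, strict=False, load_sub_module=True)

# With strict=True, keys in the shard that the model does not expect raise a
# RuntimeError (see the error handling in the hunk above); missing_keys lets the
# caller keep key bookkeeping across several shard loads.
```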
# ======================================
# Helper functions for saving state dict
# ======================================
@@ -350,6 +358,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
return True, index_files[0]
else:
return False, None
else:
raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
def load_state_dict(checkpoint_file_path: Path):
@@ -380,7 +390,6 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
return torch.load(checkpoint_file_path)
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
@@ -392,17 +401,18 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
return weights_name
def get_base_filenames(variant: str=None, use_safetensors: bool=False):
"""
generate base weight filenames
"""
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)
def get_base_filenames(variant: str = None, use_safetensors: bool = False):
"""
generate base weight filenames
"""
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = add_variant(save_index_file, variant)
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = add_variant(save_index_file, variant)
return weights_name, save_index_file
return weights_name, save_index_file
def get_shard_filename(weights_name: str, idx: int):
"""
@@ -410,4 +420,4 @@ def get_shard_filename(weights_name: str, idx: int):
"""
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
return shard_file
return shard_file

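For reference, the names these helpers produce in the default (non-safetensors, no variant) case, using the constants shown at the top of this file; the import path is assumed.

```python
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename  # path assumed

weights_name, save_index_file = get_base_filenames(variant=None, use_safetensors=False)
print(weights_name)                          # pytorch_model.bin
print(save_index_file)                       # pytorch_model.bin.index.json
print(get_shard_filename(weights_name, 0))   # pytorch_model-00001.bin
print(get_shard_filename(weights_name, 1))   # pytorch_model-00002.bin
```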
View File

@@ -11,9 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
# These models are not compatible with AMP
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`']
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
# These models have no parameters
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
# These models will get stuck
_STUCK_MODELS = [
'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
@@ -67,6 +67,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
skipped_models.append(name)
continue
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
if err is None:
@@ -91,7 +92,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use()
def test_low_level_zero_plugin(early_stop: bool = True):
spawn(run_dist, 2, early_stop=early_stop)
spawn(run_dist, 4, early_stop=early_stop)
if __name__ == '__main__':

View File

@@ -1,6 +1,7 @@
import tempfile
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from torchvision.models import resnet18
@@ -8,12 +9,12 @@ from torchvision.models import resnet18
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPCheckpointIO
from colossalai.interface import OptimizerWrapper
from colossalai.testing import check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
def check_torch_ddp_checkpointIO():
@parameterize('shard', [True, False])
def check_torch_ddp_checkpointIO(shard: bool):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
@@ -34,23 +35,38 @@ def check_torch_ddp_checkpointIO():
optimizer.step()
scheduler.step()
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile()
ckpt_io = TorchDDPCheckpointIO()
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
ckpt_io.save_lr_scheduler(scheduler, lr_scheduler_ckpt_tempfile.name)
with tempfile.TemporaryDirectory() as tempdir:
obj = [tempdir]
dist.broadcast_object_list(obj, src=0)
tempdir = obj[0] # use the same directory on all ranks
new_model = resnet18()
new_optimizer = SGD((new_model.parameters()), lr=0.001)
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
_, new_optimizer, _, _, new_scheduler = booster.boost(new_model, new_optimizer, lr_scheduler=new_scheduler)
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
booster.save_model(model, model_ckpt_path, shard=shard)
if not shard:
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
booster.save_optimizer(optimizer, optimizer_ckpt_path)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
dist.barrier()
if ckpt_io.coordinator.is_master():
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
new_model = resnet18()
new_optimizer = SGD((new_model.parameters()), lr=0.001)
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
new_model, new_optimizer, _, _, new_scheduler = booster.boost(new_model,
new_optimizer,
lr_scheduler=new_scheduler)
ckpt_io.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_tempfile.name)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
if not shard:
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
dist.barrier()
def run_dist(rank, world_size, port):