mirror of https://github.com/hpcaitech/ColossalAI
[plugin] torch ddp plugin supports sharded model checkpoint (#3775)
* [plugin] torch ddp plugin add save sharded model * [test] fix torch ddp ckpt io test * [test] fix torch ddp ckpt io test * [test] fix low level zero plugin test * [test] fix low level zero plugin test * [test] add debug info * [test] add debug info * [test] add debug info * [test] add debug info * [test] add debug info * [test] fix low level zero plugin test * [test] fix low level zero plugin test * [test] remove debug infopull/3780/head
parent
2703a37ac9
commit
5452df63c5
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, Iterator, List, Tuple, Union
|
||||
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
@ -50,6 +50,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
|||
if self.coordinator.is_master():
|
||||
super().save_lr_scheduler(lr_scheduler, checkpoint)
|
||||
|
||||
def save_sharded_model(self,
|
||||
model: nn.Module,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
if self.coordinator.is_master():
|
||||
super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
|
||||
|
||||
|
||||
class TorchDDPModel(ModelWrapper):
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -84,9 +83,8 @@ class CheckpointIO(ABC):
|
|||
# containing no distributed tensors, dtensor -> full tensor conversion
|
||||
# should be done offline via our CLI
|
||||
# the existence of index file means it is a sharded checkpoint
|
||||
ckpt_path = Path(checkpoint)
|
||||
index_file_exists, index_file_path = has_index_file(checkpoint)
|
||||
|
||||
|
||||
# return the origin model instead of the unwrapped model
|
||||
origin_model = model
|
||||
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
# coding=utf-8
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
|
||||
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
import re
|
||||
|
||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
@ -15,6 +17,7 @@ WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
|||
# General helper functions
|
||||
# ======================================
|
||||
|
||||
|
||||
def calculate_tensor_size(tensor: torch.Tensor) -> float:
|
||||
"""
|
||||
Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
|
||||
|
@ -28,6 +31,7 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
|
|||
"""
|
||||
return tensor.numel() * tensor.element_size() / 1024 / 1024
|
||||
|
||||
|
||||
def is_safetensors_available() -> bool:
|
||||
"""
|
||||
Check whether safetensors is available.
|
||||
|
@ -78,7 +82,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
|
|||
# Helper functions for saving shard file
|
||||
# ======================================
|
||||
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
|
||||
"""
|
||||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
|
@ -100,35 +103,39 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
|
|||
current_block_size = 0
|
||||
current_block[key] = weight
|
||||
current_block_size += weight_size
|
||||
|
||||
|
||||
if ret_block != None:
|
||||
yield ret_block, ret_block_size
|
||||
|
||||
yield current_block, current_block_size
|
||||
|
||||
|
||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
|
||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||
"""
|
||||
load shard state dict into model
|
||||
"""
|
||||
if use_safetensors and not checkpoint_file.suffix == ".safetensors":
|
||||
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
|
||||
if use_safetensors:
|
||||
from safetensors.torch import safe_open
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from safetensors.torch import safe_open
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata["format"] != "pt":
|
||||
raise NotImplementedError(
|
||||
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
|
||||
)
|
||||
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
|
||||
return safe_load_file(checkpoint_file)
|
||||
else:
|
||||
return torch.load(checkpoint_file)
|
||||
|
||||
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
|
||||
|
||||
|
||||
def load_state_dict_into_model(model: nn.Module,
|
||||
state_dict: torch.Tensor,
|
||||
missing_keys: List,
|
||||
strict: bool = False,
|
||||
load_sub_module: bool = True):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants.
|
||||
this module and its descendants.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
|
@ -166,11 +173,12 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
|
|||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
|
||||
', '.join('"{}"'.format(k) for k in unexpected_keys))
|
||||
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
|
||||
'"{}"'.format(k) for k in unexpected_keys))
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
# ======================================
|
||||
# Helper functions for saving state dict
|
||||
# ======================================
|
||||
|
@ -350,6 +358,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
|
|||
return True, index_files[0]
|
||||
else:
|
||||
return False, None
|
||||
else:
|
||||
raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file_path: Path):
|
||||
|
@ -380,7 +390,6 @@ def load_state_dict(checkpoint_file_path: Path):
|
|||
else:
|
||||
# load with torch
|
||||
return torch.load(checkpoint_file_path)
|
||||
|
||||
|
||||
|
||||
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
|
@ -392,17 +401,18 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
|||
return weights_name
|
||||
|
||||
|
||||
def get_base_filenames(variant: str=None, use_safetensors: bool=False):
|
||||
"""
|
||||
generate base weight filenames
|
||||
"""
|
||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||
weights_name = add_variant(weights_name, variant)
|
||||
def get_base_filenames(variant: str = None, use_safetensors: bool = False):
|
||||
"""
|
||||
generate base weight filenames
|
||||
"""
|
||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||
weights_name = add_variant(weights_name, variant)
|
||||
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||
save_index_file = add_variant(save_index_file, variant)
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||
save_index_file = add_variant(save_index_file, variant)
|
||||
|
||||
return weights_name, save_index_file
|
||||
|
||||
return weights_name, save_index_file
|
||||
|
||||
def get_shard_filename(weights_name: str, idx: int):
|
||||
"""
|
||||
|
@ -410,4 +420,4 @@ def get_shard_filename(weights_name: str, idx: int):
|
|||
"""
|
||||
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
|
||||
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
|
||||
return shard_file
|
||||
return shard_file
|
||||
|
|
|
@ -11,9 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
# These models are not compatible with AMP
|
||||
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`']
|
||||
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
|
||||
# These models have no parameters
|
||||
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
|
||||
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
|
||||
# These models will get stuck
|
||||
_STUCK_MODELS = [
|
||||
'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
|
||||
|
@ -67,6 +67,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
|
|||
skipped_models.append(name)
|
||||
continue
|
||||
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if err is None:
|
||||
|
@ -91,7 +92,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
|
|||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_low_level_zero_plugin(early_stop: bool = True):
|
||||
spawn(run_dist, 2, early_stop=early_stop)
|
||||
spawn(run_dist, 4, early_stop=early_stop)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import SGD
|
||||
from torchvision.models import resnet18
|
||||
|
@ -8,12 +9,12 @@ from torchvision.models import resnet18
|
|||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import TorchDDPPlugin
|
||||
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPCheckpointIO
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.testing import check_state_dict_equal, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_torch_ddp_checkpointIO():
|
||||
@parameterize('shard', [True, False])
|
||||
def check_torch_ddp_checkpointIO(shard: bool):
|
||||
plugin = TorchDDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
model = resnet18()
|
||||
|
@ -34,23 +35,38 @@ def check_torch_ddp_checkpointIO():
|
|||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
|
||||
lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile()
|
||||
ckpt_io = TorchDDPCheckpointIO()
|
||||
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
|
||||
ckpt_io.save_lr_scheduler(scheduler, lr_scheduler_ckpt_tempfile.name)
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
obj = [tempdir]
|
||||
dist.broadcast_object_list(obj, src=0)
|
||||
tempdir = obj[0] # use the same directory on all ranks
|
||||
|
||||
new_model = resnet18()
|
||||
new_optimizer = SGD((new_model.parameters()), lr=0.001)
|
||||
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
|
||||
_, new_optimizer, _, _, new_scheduler = booster.boost(new_model, new_optimizer, lr_scheduler=new_scheduler)
|
||||
model_ckpt_path = f"{tempdir}/model"
|
||||
optimizer_ckpt_path = f"{tempdir}/optimizer"
|
||||
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
|
||||
booster.save_model(model, model_ckpt_path, shard=shard)
|
||||
if not shard:
|
||||
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
|
||||
booster.save_optimizer(optimizer, optimizer_ckpt_path)
|
||||
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
|
||||
dist.barrier()
|
||||
|
||||
if ckpt_io.coordinator.is_master():
|
||||
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
|
||||
new_model = resnet18()
|
||||
new_optimizer = SGD((new_model.parameters()), lr=0.001)
|
||||
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
|
||||
new_model, new_optimizer, _, _, new_scheduler = booster.boost(new_model,
|
||||
new_optimizer,
|
||||
lr_scheduler=new_scheduler)
|
||||
|
||||
ckpt_io.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_tempfile.name)
|
||||
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
|
||||
|
||||
if not shard:
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
|
||||
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
|
||||
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
|
Loading…
Reference in New Issue