[CheckpointIO] a uniform checkpoint I/O module (#1689)

pull/1825/head
ver217 2022-11-08 15:15:13 +08:00 committed by GitHub
parent 629172b319
commit 99870726b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 2111 additions and 0 deletions

View File

@ -0,0 +1,2 @@
from .io import load, merge, redist, save
from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta)

View File

@ -0,0 +1,74 @@
import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Type
from .reader import CheckpointReader, DiskCheckpointReader
from .writer import CheckpointWriter, DiskCheckpointWriter
_backends: Dict[str, Type['CheckpointIOBackend']] = {}
def register(name: str):
    """Build a class decorator that records a backend class under ``name``.

    Args:
        name: Key under which the decorated class is stored in ``_backends``.

    Returns:
        A decorator that stores the class it receives and returns it unchanged.

    Raises:
        AssertionError: If ``name`` is already present in ``_backends``.
    """
    # The duplicate check runs when the decorator is created, not when it is applied.
    assert name not in _backends, f'"{name}" is registered'

    def _decorate(backend_cls):
        _backends[name] = backend_cls
        return backend_cls

    return _decorate
def get_backend(name: str) -> 'CheckpointIOBackend':
    """Create a fresh instance of the backend registered under ``name``.

    Args:
        name: Key previously passed to ``register``.

    Returns:
        A new instance of the registered class (constructed with no arguments).

    Raises:
        AssertionError: If no backend was registered under ``name``.
    """
    assert name in _backends, f'Unsupported backend "{name}"'
    backend_cls = _backends[name]
    return backend_cls()
class CheckpointIOBackend(ABC):
    """Abstract factory for checkpoint readers, writers and temporary locations.

    Attributes are limited to ``temps``, a list in which concrete backends can
    record the temporary locations they hand out.
    """

    def __init__(self) -> None:
        super().__init__()
        # Temporary locations handed out so far.
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a writer for the checkpoint located at ``base_name``."""

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a reader for the checkpoint located at ``base_name``."""

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        """Return a temporary location associated with ``base_name``."""

    @abstractmethod
    def clean_temp(self) -> None:
        """Release the temporary locations handed out by ``get_temp``."""
@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):
    """Checkpoint backend that stores everything in directories on disk."""

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a ``DiskCheckpointWriter`` for ``base_name``."""
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a ``DiskCheckpointReader`` for ``base_name``."""
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        """Create a fresh temporary directory inside ``base_name`` and remember it.

        Returns:
            The path of the newly created directory.
        """
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        """Delete every directory created by ``get_temp`` and forget it.

        Each entry is dropped from ``self.temps`` before it is removed, so
        calling this method again never retries a directory that has already
        been handled; if a removal raises, the directories not yet processed
        remain tracked for a later call.
        """
        while self.temps:
            shutil.rmtree(self.temps.pop())

View File

@ -0,0 +1,9 @@
import re
# Fixed file names of the checkpoint files. The reader module imports the
# global-meta and "other" names to locate those two files inside a checkpoint directory.
GLOBAL_META_FILE_NAME = 'global_meta.bin'
MODEL_CKPT_FILE_NAME = 'model.bin'
OPTIM_CKPT_FILE_NAME = 'optim.bin'
META_CKPT_FILE_NAME = 'meta.bin'
OTHER_CKPT_FILE_NAME = 'other.bin'
# Unanchored pattern matching any of the file-name stems above.
# NOTE(review): its consumer is not visible here — presumably the writer uses it to
# recognise existing checkpoint files; confirm before changing the alternation order.
CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')

View File

@ -0,0 +1,227 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional
from torch import Tensor
from .distributed import merge_param, unmerge_param
from .meta import ParamDistMeta, RedistMeta
from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none)
class CheckpointConvertor(ABC):
    """Interface for objects that consume checkpoint shards incrementally."""

    @abstractmethod
    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        """Feed one batch of shards.

        Args:
            shard_dict: Mapping from rank to the shard loaded for that rank.
            dist_meta_list: Per-rank distribution metadata, indexed by rank;
                an entry may be ``None``.
        """

    @abstractmethod
    def complete(self) -> None:
        """Signal that no further shards will be appended."""
class ModelCheckpointConvertor(CheckpointConvertor):
    """Collects per-rank copies of each parameter and converts complete sets.

    A parameter is considered complete once the number of ranks that supplied
    it equals ``param_count[name]``.
    """

    def __init__(self, param_count: Dict[str, int]) -> None:
        super().__init__()
        self.param_count = param_count
        # name -> {rank: tensor}; entries are removed once converted.
        self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)

    @abstractmethod
    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        """Handle one parameter for which every expected rank copy has arrived."""

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        """Buffer the tensors in ``shard_dict`` and convert every completed parameter.

        Args:
            shard_dict: Mapping from rank to that rank's model state-dict shard.
            dist_meta_list: Per-rank distribution metadata indexed by rank.
                Ranks whose entry is ``None`` contribute a tensor but no metadata.
        """
        for src_rank, shard in shard_dict.items():
            for name, value in shard.items():
                self.buffer[name][src_rank] = value

        finished = []
        for name, per_rank in self.buffer.items():
            if len(per_rank) != self.param_count[name]:
                continue
            tensors = list(per_rank.values())
            dist_metas = [dist_meta_list[r][name] for r in per_rank if dist_meta_list[r] is not None]
            self.convert_tensors(name, tensors, dist_metas)
            finished.append(name)

        # Deletion is deferred so the buffer is not mutated while it is being iterated.
        for name in finished:
            del self.buffer[name]

    def complete(self) -> None:
        """Check that every buffered parameter has been converted."""
        assert len(self.buffer) == 0
class ModelCheckpointMerger(ModelCheckpointConvertor):
    """Merges the per-rank pieces of every parameter and saves them in shards."""

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None:
        super().__init__(param_count)
        self.sharder = ModelCheckpointSharder(max_shard_size)
        self.save_fn = save_fn

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        """Merge one parameter and pass any shard the sharder hands back to ``save_fn``."""
        assert len(dist_metas) == len(tensors)
        merged = merge_param(tensors, dist_metas)
        full_shard = self.sharder.append(key, merged)
        if full_shard is not None:
            self.save_fn(full_shard)

    def complete(self) -> None:
        """Verify nothing is left unconverted, then save the final shard if there is one."""
        super().complete()
        last_shard = self.sharder.complete()
        if last_shard is not None:
            self.save_fn(last_shard)
class ModelCheckpointRedistor(ModelCheckpointConvertor):
    """Re-partitions every parameter for a new layout and saves one shard stream per target rank.

    ``save_fns[r]`` receives the shards destined for target rank ``r``.
    """

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count)
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(len(save_fns))]
        # param name -> tp_rank -> dp_rank -> list of target ranks holding that piece.
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for name, per_rank_info in redist_meta.rank_meta.items():
            for target_rank, info in per_rank_info.items():
                self.rank_map[name][info.tp_rank][info.dp_rank].append(target_rank)

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        """Rebuild the full parameter, split it for the target layout and route the pieces."""
        if dist_metas:
            assert len(dist_metas) == len(tensors)
            full_tensor = merge_param(tensors, dist_metas)
        else:
            # already global
            full_tensor = tensors[0]

        pieces = unmerge_param(full_tensor, self.redist_meta.param_meta[key])
        for tp_rank, dp_pieces in enumerate(pieces):
            for dp_rank, piece in enumerate(dp_pieces):
                for target_rank in self.rank_map[key][tp_rank][dp_rank]:
                    full_shard = self.sharders[target_rank].append(key, piece)
                    if full_shard is not None:
                        self.save_fns[target_rank](full_shard)

    def complete(self) -> None:
        """Verify nothing is left unconverted, then save each rank's final shard if there is one."""
        super().complete()
        for target_rank, save_fn in enumerate(self.save_fns):
            last_shard = self.sharders[target_rank].complete()
            if last_shard is not None:
                save_fn(last_shard)
class OptimizerCheckpointConvertor(CheckpointConvertor):
    """Collects per-rank optimizer states for each state index and converts complete sets.

    A state index is complete once the number of ranks that supplied it equals
    ``param_count`` of the parameter it belongs to.
    """

    def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]],
                 paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__()
        self.param_count = param_count
        self.param_to_os = param_to_os
        self.paired_os = paired_os
        # state index -> {rank: state dict}; entries are removed once converted.
        self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict)
        # ``param_to_os`` is declared Optional; fall back to an empty reverse map
        # instead of failing on ``None.items()``.
        self.os_to_param = {} if param_to_os is None else {v: k for k, v in param_to_os.items()}

    @abstractmethod
    def setup(self, param_groups: dict) -> None:
        """Receive the ``param_groups`` of a shard; called for every appended shard."""

    @abstractmethod
    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        """Handle one state index for which every expected rank copy has arrived."""

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        """Buffer the optimizer states in ``shard_dict`` and convert every completed index.

        Args:
            shard_dict: Mapping from rank to that rank's optimizer shard; each
                shard must provide ``'param_groups'`` and ``'state'``.
            dist_meta_list: Per-rank distribution metadata indexed by rank.
                Ranks whose entry is ``None`` contribute a state but no metadata.
        """
        for src_rank, shard in shard_dict.items():
            self.setup(shard['param_groups'])
            for idx, state in shard['state'].items():
                self.buffer[idx][src_rank] = state

        finished = []
        for idx, per_rank in self.buffer.items():
            param_name = self.os_to_param[idx]
            if len(per_rank) != self.param_count[param_name]:
                continue
            states = list(per_rank.values())
            dist_metas = [dist_meta_list[r][param_name] for r in per_rank if dist_meta_list[r] is not None]
            self.convert_states(idx, states, dist_metas)
            finished.append(idx)

        # Deletion is deferred so the buffer is not mutated while it is being iterated.
        for idx in finished:
            del self.buffer[idx]

    def complete(self) -> None:
        """Check that every buffered state has been converted."""
        assert len(self.buffer) == 0
class OptimizerCheckpointMerger(OptimizerCheckpointConvertor):
    """Merges per-rank optimizer states and saves them in shards."""

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fn = save_fn
        # Created lazily by ``setup`` because it needs the param groups of the first shard.
        self.sharder = None

    def setup(self, param_groups: dict) -> None:
        """Create the sharder from the first ``param_groups`` seen; later calls are no-ops."""
        if self.sharder is None:
            self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups)

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        """Merge the states of one index and pass any shard the sharder hands back to ``save_fn``.

        Entries flagged in ``paired_os[idx]`` are merged across ranks with
        ``merge_param``; all other entries are taken from the first state.
        """
        assert len(dist_metas) == len(states)
        new_state = {}
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas)
            else:
                new_state[state_key] = state_tensor
        shard = self.sharder.append(idx, new_state)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        """Verify nothing is left unconverted, then save the final shard if there is one."""
        super().complete()
        # No shard was ever appended (e.g. the checkpoint holds no optimizer
        # state), so ``setup`` never ran and there is nothing to flush.
        if self.sharder is None:
            return
        run_if_not_none(self.save_fn, self.sharder.complete())
class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor):
    """Re-partitions optimizer states for a new layout, one shard stream per target rank.

    ``save_fns[r]`` receives the shards destined for target rank ``r``.
    """

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        # Filled lazily by ``setup``; stays empty if no shard is ever appended.
        self.sharders: List[OptimizerCheckpointSharder] = []
        # param name -> tp_rank -> dp_rank -> list of target ranks holding that piece.
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def setup(self, param_groups: dict) -> None:
        """Create one sharder per target rank from the first ``param_groups`` seen."""
        if len(self.sharders) == 0:
            for _ in range(len(self.save_fns)):
                self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups))

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        """Rebuild the state of one index, split it for the target layout and route the pieces.

        Entries flagged in ``paired_os[idx]`` are merged (when metadata is
        available) and re-split per target rank; all other entries are copied
        to every target rank unchanged.
        """
        # Without any dist meta the states are already global and need no merge.
        need_merge = len(dist_metas) != 0
        if need_merge:
            assert len(dist_metas) == len(states)
        param_name = self.os_to_param[idx]
        new_states = [{} for _ in range(len(self.save_fns))]
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                if need_merge:
                    tensor = merge_param([state[state_key] for state in states], dist_metas)
                else:
                    tensor = state_tensor
                pieces = unmerge_param(tensor, self.redist_meta.param_meta[param_name])
                for tp_rank, tensor_list in enumerate(pieces):
                    for dp_rank, t in enumerate(tensor_list):
                        for rank in self.rank_map[param_name][tp_rank][dp_rank]:
                            new_states[rank][state_key] = t
            else:
                for new_state in new_states:
                    new_state[state_key] = state_tensor
        for rank, new_state in enumerate(new_states):
            shard = self.sharders[rank].append(idx, new_state)
            run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        """Verify nothing is left unconverted, then save each rank's final shard if there is one."""
        super().complete()
        # ``self.sharders`` is either empty (``setup`` never ran) or has one entry
        # per save function, so zipping skips the flush safely in the empty case.
        for save_fn, sharder in zip(self.save_fns, self.sharders):
            run_if_not_none(save_fn, sharder.complete())

View File

@ -0,0 +1,127 @@
import torch
from numpy import prod
from torch import Tensor
from typing import List, Optional, Tuple
from collections import defaultdict
from .meta import ParamDistMeta, ParamRedistMeta
def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    """Reassemble a parameter from pieces that were flattened and split across ranks.

    Args:
        tensors: The pieces, one per entry of ``dist_metas``.
        dist_metas: Metadata for each piece; all must share the same ``zero_meta``.

    Returns:
        ``tensors[0]`` when the first meta reports ``used_zero`` as false
        (the pieces are replicas); otherwise the pieces concatenated in
        ascending ``dp_rank`` order and reshaped to ``zero_orig_shape``.
    """
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    reference = dist_metas[0]
    for other in dist_metas[1:]:
        assert other.zero_meta == reference.zero_meta, 'Expect all params have the same zero meta.'
    if not reference.used_zero:
        # tensors are replicas
        return tensors[0]
    expected_numel = reference.zero_numel
    target_shape = reference.zero_orig_shape
    # Stable ordering by dp_rank, so equal ranks keep their input order.
    order = sorted(range(len(tensors)), key=lambda i: dist_metas[i].dp_rank)
    ordered = [tensors[i] for i in order]
    assert expected_numel == sum(t.numel() for t in ordered), 'Expect numel of all params is equal to zero_numel.'
    return torch.cat(ordered).reshape(target_shape)
def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    """Concatenate shards of one parameter back into the full tensor.

    Args:
        tensors: The shards, ordered so that shards of the highest sharded
            dimension are adjacent; all must have the same shape.
        dist_metas: Metadata for each shard; all must share the same ``tp_meta``.

    Returns:
        ``tensors[0]`` when the first meta reports ``used_tp`` as false
        (the shards are replicas); otherwise the gathered tensor.

    Raises:
        AssertionError: On inconsistent inputs, or if ``prod(tp_num_parts)``
            differs from ``tp_world_size``.
    """
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    # Read everything from the first meta. Relying on a loop variable here
    # would leave it unbound when exactly one meta is supplied.
    ref_meta = dist_metas[0]
    for other in dist_metas[1:]:
        assert other.tp_meta == ref_meta.tp_meta, 'Expect all params have the same tp meta.'
    for t in tensors[1:]:
        assert t.shape == tensors[0].shape, 'Expect all params have the same shape.'
    if not ref_meta.used_tp:
        # tensors are replicas
        return tensors[0]
    total_parts = prod(ref_meta.tp_num_parts)
    assert ref_meta.tp_world_size == total_parts, \
        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {ref_meta.tp_world_size}.'
    # Undo the split highest dimension first: every pass joins groups of
    # ``num_parts`` adjacent tensors along ``dim``.
    shard_info = sorted(zip(ref_meta.tp_shard_dims, ref_meta.tp_num_parts), key=lambda t: t[0], reverse=True)
    for dim, num_parts in shard_info:
        tensors = [torch.cat(tensors[start:start + num_parts], dim) for start in range(0, len(tensors), num_parts)]
    assert len(tensors) == 1
    return tensors[0]
def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None:
    """Assert that all metas agree on ``dp_world_size`` and ``tp_world_size``.

    Raises:
        AssertionError: If ``dist_metas`` is empty or any entry disagrees with the first.
    """
    assert len(dist_metas) > 0
    first = dist_metas[0]
    # check world size
    for meta in dist_metas[1:]:
        assert meta.dp_world_size == first.dp_world_size, 'Expect all dist meta have the same dp_world_size'
        assert meta.tp_world_size == first.tp_world_size, 'Expect all dist meta have the same tp_world_size'
def deduplicate_params(tensors: List[Tensor],
                       dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]:
    """Keep only the first tensor for every distinct dist meta.

    Metas are compared with ``==``; the first occurrence wins and the input
    order of the survivors is preserved.

    Returns:
        The surviving tensors and their metas as two parallel lists.
    """
    seen = []
    kept_positions = []
    for position, meta in enumerate(dist_metas):
        if meta in seen:
            continue
        seen.append(meta)
        kept_positions.append(position)
    kept_tensors = [tensors[i] for i in kept_positions]
    kept_metas = [dist_metas[i] for i in kept_positions]
    return kept_tensors, kept_metas
def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    """Rebuild one full parameter from the pieces held by different ranks.

    Duplicate pieces (equal dist meta) are dropped, the remaining pieces are
    grouped by ``tp_rank`` and un-flattened per group, and the groups are then
    gathered in ascending ``tp_rank`` order.

    Args:
        tensors: The pieces, one per entry of ``dist_metas``, in any order.
        dist_metas: Metadata describing where each piece came from.

    Returns:
        The merged tensor.

    Raises:
        AssertionError: On empty or mismatched inputs, inconsistent world
            sizes, or when the number of distinct ``tp_rank`` values differs
            from ``tp_world_size``.
    """
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    # validate parallel info
    validate_parallel_info(dist_metas)
    tensors, dist_metas = deduplicate_params(tensors, dist_metas)
    # group zero params by tp rank
    tensor_dict = defaultdict(list)
    dist_meta_dict = defaultdict(list)
    for t, dist_meta in zip(tensors, dist_metas):
        tensor_dict[dist_meta.tp_rank].append(t)
        dist_meta_dict[dist_meta.tp_rank].append(dist_meta)
    assert len(tensor_dict
              ) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}'
    # The gather step concatenates positionally, so the groups must be visited
    # in tp_rank order rather than in the order the pieces happened to arrive.
    ordered_tp_ranks = sorted(tensor_dict)
    unflattened_tensors = [
        unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank]) for tp_rank in ordered_tp_ranks
    ]
    return gather_tp_param(unflattened_tensors, [dist_meta_dict[tp_rank][0] for tp_rank in ordered_tp_ranks])
def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    """Split a full tensor into ``tp_world_size`` contiguous shards.

    The sharded dimensions are processed in ascending order; each pass chunks
    every tensor produced so far into ``num_parts`` equal pieces.

    Returns:
        ``[tensor]`` when ``redist_meta.used_tp`` is false, otherwise the list of shards.

    Raises:
        AssertionError: If the metadata is inconsistent or a dimension is not
            divisible by its number of parts.
    """
    if not redist_meta.used_tp:
        assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.'
        return [tensor]
    total_parts = prod(redist_meta.tp_num_parts)
    assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.'
    plan = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0])
    pieces = [tensor]
    for dim, num_parts in plan:
        next_pieces = []
        for piece in pieces:
            assert piece.size(dim) % num_parts == 0, \
                f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.'
            next_pieces.extend(chunk.contiguous() for chunk in piece.chunk(num_parts, dim))
        pieces = next_pieces
    assert len(pieces) == redist_meta.tp_world_size
    return pieces
def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    """Produce one entry per dp rank from ``tensor``.

    Returns:
        A list of length ``dp_world_size``. When ``used_zero`` is false every
        entry is ``tensor`` itself. Otherwise ranks before
        ``zero_start_dp_rank`` get an empty tensor, the following ranks get
        consecutive slices of the flattened tensor delimited by
        ``zero_offsets``, and any remaining ranks get an empty tensor.
    """
    world = redist_meta.dp_world_size
    if not redist_meta.used_zero:
        return [tensor] * world

    pieces: List[Optional[Tensor]] = [
        torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank)
    ]
    # Append the total size so consecutive pairs give (start, end) of every slice.
    bounds = redist_meta.zero_offsets + [tensor.numel()]
    for begin, end in zip(bounds, bounds[1:]):
        pieces.append(tensor.view(-1)[begin:end])
    # Pad the tail so ranks after the last slice still get an (empty) entry.
    missing = world - len(pieces)
    if missing > 0:
        pieces.extend(torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(missing))
    assert len(pieces) == world
    return pieces
def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]:
    """Partition a full tensor for the layout described by ``redist_meta``.

    Returns:
        A nested list: the outer index follows the output of
        ``split_tp_param``, the inner index the output of ``flatten_zero_param``.
    """
    return [flatten_zero_param(tp_shard, redist_meta) for tp_shard in split_tp_param(tensor, redist_meta)]

View File

@ -0,0 +1,170 @@
import warnings
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
import torch.distributed as dist
from torch.nn import Module
from torch.optim import Optimizer
from .backend import get_backend
from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
OptimizerCheckpointRedistor)
from .meta import ParamDistMeta, RedistMeta
from .utils import build_checkpoints, optimizer_load_state_dict
def save(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         param_to_os: Optional[Dict[str, int]] = None,
         dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
         max_shard_size_gb: float = 0.0,
         overwrite: bool = False,
         backend: str = 'disk',
         **kwargs: Any) -> None:
    """Write a checkpoint of ``model`` (and optionally ``optimizer``) through a backend.

    Args:
        path: Destination handed to the backend's writer.
        model: Module whose state is saved.
        optimizer: Optional optimizer whose state is saved as well.
        param_to_os: Optional mapping from parameter name to optimizer state index.
        dist_meta: Per-parameter distribution metadata; required when the
            world size is larger than one and ignored otherwise.
        max_shard_size_gb: Shard size limit in GiB, converted to bytes before use.
        overwrite: Forwarded to the writer.
        backend: Name of a registered backend.
        **kwargs: Extra objects, saved via ``writer.save_others``.

    Raises:
        AssertionError: If running with several processes and ``dist_meta`` is ``None``.
    """
    io_backend = get_backend(backend)
    distributed = dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    world_size = dist.get_world_size() if distributed else 1

    if world_size != 1:
        assert dist_meta is not None
    else:
        # global doesn't need dist_meta
        dist_meta = None

    max_shard_size = int(max_shard_size_gb * 1024**3)
    model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer,
                                                                                  param_to_os, dist_meta)
    writer = io_backend.get_writer(path, overwrite, rank, world_size)
    writer.save_others(kwargs)
    for model_shard in model_checkpoints:
        writer.save_model(model_shard)
    for optimizer_shard in optimizer_checkpoints:
        writer.save_optimizer(optimizer_shard)
    writer.save_meta(meta_checkpoint)
def merge(path: str,
          output_path: str,
          max_shard_size_gb: float = 0.0,
          overwrite: bool = False,
          backend: str = 'disk') -> bool:
    """Merge a distributed checkpoint into a single global checkpoint.

    Only rank 0 does any work when ``torch.distributed`` is initialized.

    Args:
        path: Location of the checkpoint to read.
        output_path: Location where the merged checkpoint is written.
        max_shard_size_gb: Shard size limit in GiB, converted to bytes before use.
        overwrite: Forwarded to the writer.
        backend: Name of a registered backend.

    Returns:
        ``True`` if a merged checkpoint was written; ``False`` on non-zero
        ranks or when the source holds a single meta entry (already global).
    """
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    reader = io_backend.get_reader(path)
    if len(reader.meta_list) == 1:
        # already global
        warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
        return False
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    writer = io_backend.get_writer(output_path, overwrite=overwrite)
    writer.save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(),
                    dist_meta_list)
    # ``param_to_os`` is only present in the meta when an optimizer was saved;
    # the optimizer merger cannot be built without it, so skip that pass.
    has_optimizer = param_to_os is not None
    if has_optimizer:
        _convert_shards(
            OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os),
            reader.load_optimizers(), dist_meta_list)
    meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
    if has_optimizer:
        meta_checkpoint['param_to_os'] = param_to_os
        meta_checkpoint['paired_os'] = paired_os
    writer.save_meta(meta_checkpoint)
    return True
def redist(path: str,
           output_path: str,
           redist_meta: RedistMeta,
           dist_metas: List[Dict[str, ParamDistMeta]],
           max_shard_size_gb: float = 0.0,
           overwrite: bool = False,
           backend: str = 'disk') -> bool:
    """Rewrite a checkpoint for a different parallel layout.

    Only rank 0 does any work when ``torch.distributed`` is initialized.

    Args:
        path: Location of the checkpoint to read.
        output_path: Location where the redistributed checkpoint is written.
        redist_meta: Description of the target layout.
        dist_metas: Target per-rank distribution metadata; its length is the
            number of target ranks.
        max_shard_size_gb: Shard size limit in GiB, converted to bytes before use.
        overwrite: Forwarded to every writer.
        backend: Name of a registered backend.

    Returns:
        ``True`` if a new checkpoint was written; ``False`` on non-zero ranks
        or when the stored layout already equals ``dist_metas``.
    """
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    nprocs = len(dist_metas)
    reader = io_backend.get_reader(path)
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    # Redistribution is needed when the rank count differs or any rank's meta differs.
    do_redist = len(dist_meta_list) != nprocs or any(a != b for a, b in zip(dist_metas, dist_meta_list))
    if not do_redist:
        warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.')
        return False

    writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
    writers[0].save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(
        ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta),
        reader.load_models(), dist_meta_list)
    # ``param_to_os`` is only present in the meta when an optimizer was saved;
    # the optimizer redistributor cannot be built without it, so skip that pass.
    has_optimizer = param_to_os is not None
    if has_optimizer:
        _convert_shards(
            OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count,
                                        param_to_os, paired_os, redist_meta), reader.load_optimizers(),
            dist_meta_list)
    for writer, dist_meta in zip(writers, dist_metas):
        meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
        if has_optimizer:
            meta_checkpoint['param_to_os'] = param_to_os
            meta_checkpoint['paired_os'] = paired_os
        writer.save_meta(meta_checkpoint)
    return True
def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
                    dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
    """Feed every item of ``shard_generator`` to ``convertor``, then finalize it.

    ``convertor.complete()`` is called exactly once, after the generator is
    exhausted (also when it yields nothing).
    """
    for batch in shard_generator:
        convertor.append(batch, dist_meta_list)
    convertor.complete()
def load(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         redist_meta: Optional[RedistMeta] = None,
         dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
         max_shard_size_gb: float = 0.0,
         backend: str = 'disk') -> dict:
    """Load a checkpoint into ``model`` (and ``optimizer``), converting its layout first if needed.

    Rank 0 pre-processes the checkpoint into a temporary location: it is
    merged when running as a single process and redistributed otherwise. With
    several processes, the location to read is broadcast from rank 0 and all
    ranks synchronize before rank 0 removes the temporary data.

    Args:
        path: Location of the checkpoint.
        model: Module receiving the model shards (loaded with ``strict=False``).
        optimizer: Optional optimizer receiving the optimizer shards.
        redist_meta: Target layout; required together with ``dist_metas`` when
            running with several processes.
        dist_metas: Target per-rank distribution metadata.
        max_shard_size_gb: Shard size limit used for the temporary checkpoint.
        backend: Name of a registered backend.

    Returns:
        The dictionary returned by the reader's ``load_others``.

    Raises:
        AssertionError: If layout metadata is missing while running with several processes.
    """
    distributed = dist.is_initialized()
    is_global: bool = not distributed or dist.get_world_size() == 1
    rank: int = dist.get_rank() if distributed else 0
    is_main_process: bool = rank == 0

    # validate args
    if redist_meta is None or dist_metas is None:
        assert is_global

    io_backend = get_backend(backend)
    read_path: str = path
    if is_main_process:
        # pre-process checkpoints
        temp_path = io_backend.get_temp(path)
        if is_global:
            converted = merge(path, temp_path, max_shard_size_gb, backend=backend)
        else:
            converted = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
        # Fall back to the original location when no conversion was necessary.
        if converted:
            read_path = temp_path

    if not is_global:
        # Only rank 0 knows where to read from; share that with everyone.
        bcast_list = [read_path if is_main_process else None]
        dist.broadcast_object_list(bcast_list)
        read_path = bcast_list[0]

    reader = io_backend.get_reader(read_path)
    # load model
    for model_shard in reader.load_model(rank):
        model.load_state_dict(model_shard, strict=False)
    if optimizer is not None:
        for optimizer_shard in reader.load_optimizer(rank):
            optimizer_load_state_dict(optimizer, optimizer_shard)
    others_dict = reader.load_others()

    if not is_global:
        # Every rank must finish reading before the temporary data disappears.
        dist.barrier()
    # clean up temp
    if is_main_process:
        io_backend.clean_temp()
    return others_dict

View File

@ -0,0 +1,81 @@
from dataclasses import dataclass
from typing import List, Optional, Set, Dict
@dataclass
class ParamDistMeta:
    """Describes how one parameter is distributed on the rank that saved it."""

    # parallel info
    dp_rank: int
    dp_world_size: int
    tp_rank: int
    tp_world_size: int

    # tp info: sharded dimensions and the number of parts along each of them
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None

    # zero info: total element count and the shape to restore after concatenation
    zero_numel: Optional[int] = None
    zero_orig_shape: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        """Whether both tp fields are set."""
        return not (self.tp_shard_dims is None or self.tp_num_parts is None)

    @property
    def used_zero(self) -> bool:
        """Whether both zero fields are set."""
        return not (self.zero_numel is None or self.zero_orig_shape is None)

    @property
    def parallel_meta(self) -> tuple:
        """``(dp_rank, dp_world_size, tp_rank, tp_world_size)``."""
        return (self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size)

    @property
    def tp_meta(self) -> tuple:
        """``(tp_shard_dims, tp_num_parts)``."""
        return (self.tp_shard_dims, self.tp_num_parts)

    @property
    def zero_meta(self) -> tuple:
        """``(zero_numel, zero_orig_shape)``."""
        return (self.zero_numel, self.zero_orig_shape)

    @staticmethod
    def from_dict(d: dict) -> 'ParamDistMeta':
        """Build an instance from a dict whose keys are field names."""
        return ParamDistMeta(**d)
@dataclass
class ParamRedistMeta:
    """Describes how one parameter should be partitioned in the target layout."""

    # parallel info
    dp_world_size: int
    tp_world_size: int

    # tp info: sharded dimensions and the number of parts along each of them
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None

    # zero info: first dp rank that receives a slice, and the start offset of every slice
    zero_start_dp_rank: Optional[int] = None
    zero_offsets: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        """Whether both tp fields are set."""
        return not (self.tp_shard_dims is None or self.tp_num_parts is None)

    @property
    def used_zero(self) -> bool:
        """Whether both zero fields are set."""
        return not (self.zero_start_dp_rank is None or self.zero_offsets is None)
@dataclass
class RankRedistMeta:
    """Position of one target rank for a given parameter.

    The convertors use ``tp_rank`` and ``dp_rank`` to pick the piece of a
    re-partitioned parameter that belongs to the rank.
    """

    dp_rank: int
    tp_rank: int
    # NOTE(review): presumably the pipeline-parallel rank; it is not read in the visible code.
    pp_rank: int
@dataclass
class PipelineRedistMeta:
    """A set of parameter names.

    NOTE(review): presumably the parameters owned by one pipeline stage — confirm with the caller.
    """

    params: Set[str]
@dataclass
class RedistMeta:
    """Complete description of a target layout used for redistribution."""

    # parameter name -> target rank -> position of that rank for the parameter
    rank_meta: Dict[str, Dict[int, RankRedistMeta]]
    # NOTE(review): presumably one entry per pipeline stage; not read in the visible code.
    pipeline_meta: List[PipelineRedistMeta]
    # parameter name -> how the parameter is partitioned in the target layout
    param_meta: Dict[str, ParamRedistMeta]

View File

@ -0,0 +1,131 @@
import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generator, List, Optional, Tuple
import torch
from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
from .meta import ParamDistMeta
from .utils import is_duplicated_list
class CheckpointReader(ABC):
    """Interface for reading the pieces of a checkpoint stored under ``base_name``."""

    def __init__(self, base_name: str) -> None:
        super().__init__()
        self.base_name = base_name
        # Meta entries of the checkpoint; filled by concrete readers.
        self.meta_list = []

    @abstractmethod
    def read(self, name: str) -> dict:
        """Read and return the object stored under ``name``."""

    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Return ``(dist_meta_list, param_count, param_to_os, paired_os)``."""

    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield the model shards belonging to ``rank``."""

    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield dicts mapping rank to a model shard."""

    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield the optimizer shards belonging to ``rank``."""

    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield dicts mapping rank to an optimizer shard."""

    @abstractmethod
    def load_others(self) -> dict:
        """Return the extra objects stored alongside the checkpoint."""
class DiskCheckpointReader(CheckpointReader):
    """Reads a checkpoint stored as files inside the directory ``base_name``.

    On construction the global meta file is read and every meta file it lists
    is loaded into ``self.meta_list`` (one entry per saved rank).
    """

    def __init__(self, base_name: str) -> None:
        super().__init__(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        global_meta = self.read(GLOBAL_META_FILE_NAME)
        for meta_file_name in global_meta['meta']:
            meta = self.read(meta_file_name)
            if meta.get('dist_meta', None) is None:
                # only global checkpoint can have empty dist_meta
                assert len(global_meta['meta']) == 1
            self.meta_list.append(meta)

    def read(self, name: str) -> dict:
        """Load the file ``name`` from the checkpoint directory.

        NOTE(review): ``torch.load`` unpickles the file, so only checkpoints
        from trusted sources should be read.
        """
        return torch.load(os.path.join(self.base_name, name))

    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Combine the per-rank meta entries.

        Returns:
            A tuple ``(dist_meta_list, param_count, param_to_os, paired_os)``
            where ``param_count`` counts in how many meta entries each
            parameter name appears, and the last two items are taken from the
            first entry after checking that all entries agree.
        """
        meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os',
                                                                             None), meta.get('paired_os', None))
                      for meta in self.meta_list]
        dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
        # reduce param_count
        param_count = Counter(p for params in params_list for p in params)
        # validate param_to_os
        assert is_duplicated_list(param_to_os_list)
        assert is_duplicated_list(paired_os_list)
        return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]

    def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
        """Yield, in order, every file listed under ``shard_type`` in the meta of ``rank``."""
        meta = self.meta_list[rank]
        checkpoint_names = meta.get(shard_type, [])
        for name in checkpoint_names:
            yield self.read(name)

    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield the model shards of ``rank``."""
        return self._load_shard('model', rank)

    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield rounds of model shards: the n-th round maps each rank to its n-th shard.

        Ranks that have run out of shards are absent from later rounds.
        """
        indices = [0] * len(self.meta_list)
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                model_checkpoint_names = meta.get('model', [])
                if indices[i] < len(model_checkpoint_names):
                    shards[i] = self.read(model_checkpoint_names[indices[i]])
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield the optimizer shards of ``rank``.

        The ``'param_groups'`` of the first shard are attached to every later
        shard before it is yielded.
        """
        param_groups = None
        for shard in self._load_shard('optimizer', rank):
            if param_groups is None:
                param_groups = shard['param_groups']
            else:
                shard['param_groups'] = param_groups
            yield shard

    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield rounds of optimizer shards: the n-th round maps each rank to its n-th shard.

        The ``'param_groups'`` of a rank's first shard are attached to every
        later shard of the same rank. Ranks that have run out of shards are
        absent from later rounds.
        """
        indices = [0] * len(self.meta_list)
        # Keyed by rank rather than stored positionally, so a rank without any
        # optimizer shard cannot shift the entries of the ranks after it.
        param_groups: Dict[int, list] = {}
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                optimizer_checkpoint_names = meta.get('optimizer', [])
                if indices[i] < len(optimizer_checkpoint_names):
                    shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
                    if indices[i] == 0:
                        param_groups[i] = shards[i]['param_groups']
                    else:
                        shards[i]['param_groups'] = param_groups[i]
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_others(self) -> dict:
        """Load the file holding the extra objects saved with the checkpoint."""
        return self.read(OTHER_CKPT_FILE_NAME)

View File

@ -0,0 +1,223 @@
import warnings
from copy import deepcopy
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple
from torch import Tensor
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from .meta import ParamDistMeta
def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
    """Call ``fn(arg)`` unless ``arg`` is ``None``.

    Returns:
        The result of ``fn(arg)``, or ``None`` when ``arg`` is ``None``
        (in which case ``fn`` is not called).
    """
    if arg is None:
        return None
    return fn(arg)
def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
    """Map parameter names of ``model`` to their index in the optimizer.

    Indices are assigned by walking ``optimizer.param_groups`` in order and
    numbering the parameters consecutively across groups; a parameter that was
    already numbered by an earlier group keeps its first index.

    Args:
        model: Module providing the parameter names.
        optimizer: Optimizer whose parameters must all belong to ``model``.

    Returns:
        Mapping from parameter name to optimizer index. Model parameters that
        the optimizer does not manage (for example frozen ones) are omitted.

    Raises:
        AssertionError: If the optimizer holds a parameter that is not part of ``model``.
    """
    # ensure all params in optimizer are in model state dict
    params_set = {id(p) for p in model.parameters()}
    for group in optimizer.param_groups:
        for p in group['params']:
            assert id(p) in params_set

    param_mappings = {}
    start_index = 0
    for group in optimizer.param_groups:
        param_mappings.update(
            {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
        start_index += len(group['params'])

    # Skip parameters without an optimizer index instead of failing with a KeyError.
    return {k: param_mappings[id(p)] for k, p in model.named_parameters() if id(p) in param_mappings}
def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
    """Return the total number of bytes held by the tensor values of ``state``.

    Values that are not tensors contribute nothing.
    """
    return sum(value.numel() * value.element_size() for value in state.values() if isinstance(value, Tensor))
class ModelCheckpointSharder:
    """Groups tensors into shards of roughly ``max_shard_size`` bytes.

    A shard is closed when, at the time a new tensor arrives, the bytes
    already buffered have reached the limit. A limit of ``0`` (or less)
    disables splitting.
    """

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, Tensor] = {}
        self.buffer_size: int = 0

    def append(self, key: str, tensor: Tensor) -> Optional[dict]:
        """Add one tensor.

        Returns:
            The previously buffered shard if it was closed by this call,
            otherwise ``None``. The new tensor always goes into the open shard.
        """
        closed = None
        if 0 < self.max_shard_size <= self.buffer_size:
            closed, self.buffer, self.buffer_size = self.buffer, {}, 0
        self.buffer[key] = tensor
        self.buffer_size += tensor.numel() * tensor.element_size()
        return closed

    def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
        """Append every item of ``state_dict`` and return the shards closed along the way."""
        closed_shards = []
        for key, tensor in state_dict.items():
            closed = self.append(key, tensor)
            if closed is not None:
                closed_shards.append(closed)
        return closed_shards

    def complete(self) -> Optional[dict]:
        """Return the open shard, or ``None`` if it is empty."""
        return self.buffer or None
class OptimizerCheckpointSharder:
    """Groups optimizer states into shards of roughly ``max_shard_size`` bytes.

    The first shard carries ``'param_groups'`` next to ``'state'``; shards
    opened after a split contain only ``'state'``. A limit of ``0`` (or less)
    disables splitting.
    """

    def __init__(self, max_shard_size: int, param_groups: dict) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
        self.buffer_size: int = 0
        self.returned_first: bool = False

    def append(self, key: int, state: dict) -> Optional[dict]:
        """Add the state stored under index ``key``.

        Returns:
            The previously buffered shard if it was closed by this call,
            otherwise ``None``. The new state always goes into the open shard.
        """
        closed = None
        if 0 < self.max_shard_size <= self.buffer_size:
            closed, self.buffer, self.buffer_size = self.buffer, {'state': {}}, 0
        self.buffer['state'][key] = state
        self.buffer_size += compute_optimizer_state_size(state)
        return closed

    def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
        """Append every entry of ``state_dict['state']`` and return the shards closed along the way."""
        closed_shards = []
        for key, state in state_dict['state'].items():
            closed = self.append(key, state)
            if closed is not None:
                closed_shards.append(closed)
        return closed_shards

    def complete(self) -> Optional[dict]:
        """Return the open shard, or ``None`` if it holds no state."""
        if len(self.buffer['state']) > 0:
            return self.buffer
        return None
def shard_checkpoint(max_shard_size: int,
                     model_state_dict: Dict[str, Tensor],
                     optimizer_state_dict: Optional[dict] = None,
                     param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
    """Split model and optimizer state dicts into size-limited shards.

    Args:
        max_shard_size: Shard size limit in bytes, forwarded to the sharders.
        model_state_dict: Model tensors keyed by name.
        optimizer_state_dict: Optional optimizer state dict with ``'state'``
            and ``'param_groups'``.
        param_to_os: Mapping from parameter name to optimizer state index;
            required when ``optimizer_state_dict`` is given.

    Returns:
        ``(model_shards, optimizer_shards)``; the second list is empty when no
        optimizer state dict is given.

    Raises:
        AssertionError: If an optimizer state index cannot be traced back to a
            key of ``model_state_dict``.
    """
    with_optimizer = optimizer_state_dict is not None
    if with_optimizer:
        assert param_to_os is not None
        os_to_param = {index: name for name, index in param_to_os.items()}
        for os_key in optimizer_state_dict['state'].keys():
            assert os_key in os_to_param
            assert os_to_param[os_key] in model_state_dict

    model_sharder = ModelCheckpointSharder(max_shard_size)
    model_shards = model_sharder.extend(model_state_dict)
    run_if_not_none(model_shards.append, model_sharder.complete())
    if not with_optimizer:
        return model_shards, []

    optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
    optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
    run_if_not_none(optimizer_shards.append, optimizer_sharder.complete())
    return model_shards, optimizer_shards
def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict:
    """Flag which optimizer state entries have the same shape as their parameter.

    Returns:
        ``{state index: {state key: bool}}`` where the flag is ``True`` only
        for tensor entries whose shape equals the shape of the parameter the
        state index belongs to.
    """
    os_to_param = {index: name for name, index in param_to_os.items()}
    paired_os = {}
    for idx, state in optimizer_state_dict['state'].items():
        param = model_state_dict[os_to_param[idx]]
        paired_os[idx] = {
            key: isinstance(value, Tensor) and value.shape == param.shape for key, value in state.items()
        }
    return paired_os
def build_checkpoints(max_size: int,
                      model: Module,
                      optimizer: Optional[Optimizer] = None,
                      param_to_os: Optional[Dict[str, int]] = None,
                      dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
                      eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
    """Build sharded model/optimizer checkpoints plus the accompanying meta dict.

    Args:
        max_size: Shard size limit forwarded to ``shard_checkpoint``.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optional optimizer whose ``state_dict()`` is saved.
        param_to_os: Parameter name -> optimizer state key; derived with
            ``get_param_to_os`` when falsy and an optimizer is given.
        dist_meta: Per-parameter distribution meta; ``None`` means a global save.
        eliminate_replica: For non-global saves, keep only parameters with
            ``used_zero`` set or ``dp_rank == 0``.

    Returns:
        ``(model_checkpoints, optimizer_checkpoints, meta)``. Both lists are empty
        (and a warning is emitted) when no parameter is left to save.
    """
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict() if optimizer else None
    meta = {'dist_meta': dist_meta}
    if optimizer:
        param_to_os = param_to_os or get_param_to_os(model, optimizer)
        meta['param_to_os'] = param_to_os
        meta['paired_os'] = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)

    if dist_meta is not None and eliminate_replica:

        def is_kept(name: str) -> bool:
            # Drop data-parallel replicas: keep zero-partitioned params and dp rank 0 copies.
            return dist_meta[name].used_zero or dist_meta[name].dp_rank == 0

        model_state_dict = {name: tensor for name, tensor in model_state_dict.items() if is_kept(name)}
        if optimizer:
            full_os = optimizer_state_dict['state']
            optimizer_state_dict['state'] = {
                param_to_os[name]: full_os[param_to_os[name]] for name in model_state_dict.keys() if is_kept(name)
            }

    meta['params'] = list(model_state_dict.keys())
    if len(model_state_dict) == 0:
        warnings.warn('model state dict is empty, checkpoint is not saved')
        return [], [], meta

    model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
                                                                param_to_os)
    return model_checkpoints, optimizer_checkpoints, meta
def is_duplicated_list(list_: List[Any]) -> bool:
    """Return ``True`` when every element compares equal to the first one.

    An empty list is considered duplicated.
    """
    if not list_:
        return True
    first = list_[0]
    return not any(item != first for item in list_[1:])
def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
    """Copy every entry of ``src_state`` into ``dest_state``.

    Entries that already exist in ``dest_state`` as tensors are updated in place
    with ``Tensor.copy_`` so the destination tensor object is kept. Every other
    entry -- a key missing from ``dest_state`` or an existing non-tensor value
    such as an integer ``step`` -- is assigned directly, so no source entry is
    silently dropped or left stale.

    Args:
        src_state: State entries to load.
        dest_state: State dict to update; mutated in place.
    """
    for key, value in src_state.items():
        current = dest_state.get(key)
        if isinstance(current, Tensor):
            current.copy_(value)
        else:
            dest_state[key] = value
def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
    """Load the ``'state'`` part of ``state_dict`` into ``optimizer``.

    The param groups of ``state_dict`` must equal those of the optimizer. Saved
    parameter ids are matched positionally to the optimizer's live parameters.

    Args:
        optimizer: Optimizer to update.
        state_dict: Optimizer state dict with ``'state'`` and ``'param_groups'``.
        strict: When ``True``, raise if keys are missing or unexpected.

    Raises:
        RuntimeError: In strict mode, when missing or unexpected keys were found.
    """
    assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
    # Work on a copy so the caller's state dict is never touched.
    state_dict = deepcopy(state_dict)
    saved_ids = chain.from_iterable(group['params'] for group in state_dict['param_groups'])
    live_params = chain.from_iterable(group['params'] for group in optimizer.param_groups)
    idx_to_p: Dict[str, Parameter] = dict(zip(saved_ids, live_params))

    missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
    unexpected_keys = []
    for idx, state in state_dict['state'].items():
        if idx not in idx_to_p:
            unexpected_keys.append(idx)
            continue
        copy_optimizer_state(state, optimizer.state[idx_to_p[idx]])

    if not strict:
        return

    error_msgs = []
    if len(unexpected_keys) > 0:
        error_msgs.insert(
            0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
    if len(missing_keys) > 0:
        # Inserted at the front, so missing keys are reported before unexpected ones.
        error_msgs.insert(
            0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
                                                                               "\n\t".join(error_msgs)))

View File

@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from typing import Optional
from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME
import torch
import os
class CheckpointWriter(ABC):
    """Abstract interface for writing model, optimizer, meta and other checkpoints."""

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__()
        self.base_name, self.overwrite = base_name, overwrite
        self.rank, self.world_size = rank, world_size
        # Derived flags: more than one process, and whether this is rank 0.
        self.is_distributed = world_size > 1
        self.is_main_process = rank == 0

    @abstractmethod
    def write(self, name: str, state_dict: dict) -> None:
        """Persist ``state_dict`` under ``name``."""

    @abstractmethod
    def save_model(self, model_checkpoint: dict) -> None:
        """Persist one model checkpoint."""

    @abstractmethod
    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        """Persist one optimizer checkpoint."""

    @abstractmethod
    def save_meta(self, meta_checkpoint: dict) -> None:
        """Persist the meta checkpoint."""

    @abstractmethod
    def save_others(self, kwargs: dict) -> None:
        """Persist any additional objects passed as ``kwargs``."""
class DiskCheckpointWriter(CheckpointWriter):
    """Checkpoint writer that stores each checkpoint with ``torch.save`` inside a directory.

    File names are derived from the ``*_FILE_NAME`` constants by string replacement
    of ``'.bin'``: a ``-rank{rank}`` suffix is inserted when ``world_size > 1`` and a
    ``-shard{idx}`` suffix for model/optimizer shards.
    """

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        """Create the target directory if needed and, on rank 0, write the global meta file.

        Raises:
            AssertionError: If ``base_name`` exists but is not a directory.
            RuntimeError: If the global meta file exists and ``overwrite`` is ``False``.
        """
        super().__init__(base_name, overwrite, rank, world_size)
        if not os.path.exists(base_name):
            # Several ranks may construct a writer for the same directory at once;
            # exist_ok avoids FileExistsError when another process wins the race.
            os.makedirs(base_name, exist_ok=True)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        self.model_checkpoint_names = []
        self.optimizer_checkpoint_names = []
        self.is_meta_saved: bool = False
        self._save_global_meta()

    def write(self, name: str, state_dict: dict) -> None:
        """Save ``state_dict`` to ``base_name/name``.

        Raises:
            RuntimeError: If the file already exists and ``overwrite`` is ``False``.
        """
        path = os.path.join(self.base_name, name)
        if os.path.exists(path) and not self.overwrite:
            raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
        torch.save(state_dict, path)

    def _save_global_meta(self) -> None:
        """On the main process, write the list of per-rank meta file names."""
        if self.is_main_process:
            global_meta = {'meta': []}
            if self.is_distributed:
                for i in range(self.world_size):
                    global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
            else:
                global_meta['meta'].append(META_CKPT_FILE_NAME)
            self.write(GLOBAL_META_FILE_NAME, global_meta)

    def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
        """Return ``base_name`` with rank and shard suffixes inserted before ``.bin``."""
        checkpoint_name = base_name
        if self.is_distributed:
            checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
        if shard_idx is not None:
            checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
        return checkpoint_name

    def save_model(self, model_checkpoint: dict) -> None:
        """Write the next model shard and remember its file name for the meta file."""
        assert not self.is_meta_saved, 'Cannot save model after saving meta'
        name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
        self.write(name, model_checkpoint)
        self.model_checkpoint_names.append(name)

    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        """Write the next optimizer shard and remember its file name for the meta file."""
        assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
        name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
        self.write(name, optimizer_checkpoint)
        self.optimizer_checkpoint_names.append(name)

    def save_meta(self, meta_checkpoint: dict) -> None:
        """Add the saved shard names to ``meta_checkpoint`` (mutated in place) and write it.

        After this call no further model or optimizer shards may be saved.
        """
        if len(self.model_checkpoint_names) > 0:
            meta_checkpoint['model'] = self.model_checkpoint_names
        if len(self.optimizer_checkpoint_names) > 0:
            meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
        self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
        self.is_meta_saved = True

    def save_others(self, kwargs: dict) -> None:
        """On the main process only, write ``kwargs`` to the "other" checkpoint file."""
        if self.is_main_process:
            self.write(OTHER_CKPT_FILE_NAME, kwargs)

View File

@ -0,0 +1,120 @@
import torch
import torch.nn as nn
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.utils import build_checkpoints
from torch.optim import Adam
class DummyModel(nn.Module):
    """Tiny fixture module with a single 20 -> 1 linear layer named ``fc``."""

    def __init__(self) -> None:
        super().__init__()
        in_features, out_features = 20, 1
        self.fc = nn.Linear(in_features, out_features)
def test_global_model():
    """With ``max_size=0`` and no optimizer, the model comes back as one checkpoint equal to its state dict."""
    model = DummyModel()
    model_ckpts, optim_ckpts, meta = build_checkpoints(0, model)
    assert len(model_ckpts) == 1
    assert len(optim_ckpts) == 0
    assert meta['dist_meta'] is None
    expected = model.state_dict()
    saved = model_ckpts[0]
    assert set(expected.keys()) == set(saved.keys())
    for name, tensor in expected.items():
        assert torch.equal(tensor, saved[name])
def test_global_model_shard():
    """With ``max_size=80`` the model is split into two disjoint shards that together cover all params."""
    model = DummyModel()
    model_ckpts, optim_ckpts, meta = build_checkpoints(80, model)
    assert len(model_ckpts) == 2
    assert len(optim_ckpts) == 0
    assert meta['dist_meta'] is None
    expected = model.state_dict()
    first_keys, second_keys = set(model_ckpts[0].keys()), set(model_ckpts[1].keys())
    assert set(expected.keys()) == first_keys | second_keys
    assert len(first_keys & second_keys) == 0
    for name, tensor in expected.items():
        for shard in model_ckpts:
            if name not in shard:
                continue
            assert torch.equal(tensor, shard[name])
def test_global_optimizer():
    """With ``max_size=0`` the optimizer comes back as one checkpoint equal to its state dict.

    Also checks the ``param_to_os`` mapping and that every state entry except
    ``'step'`` is flagged as paired in ``meta['paired_os']``.
    """
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    _, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
    assert len(optimizer_checkpoints) == 1
    assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
    for state in meta['paired_os'].values():
        for k, is_paired in state.items():
            if k == 'step':
                assert not is_paired
            else:
                assert is_paired
    orig_state_dict = optimizer.state_dict()
    state_dict = optimizer_checkpoints[0]
    for k, orig_state in orig_state_dict['state'].items():
        state = state_dict['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                # Compare the original value with the saved one; the previous
                # `v2 == v2` compared a value with itself and could never fail.
                assert v1 == v2
    assert orig_state_dict['param_groups'] == state_dict['param_groups']
def test_global_optimizer_shard():
    """With ``max_size=80`` the optimizer state is split into two disjoint shards.

    Only the first shard carries ``'param_groups'``.
    """
    model = DummyModel()
    for param in model.parameters():
        param.grad = torch.rand_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    _, optim_ckpts, _ = build_checkpoints(80, model, optimizer)
    assert len(optim_ckpts) == 2
    first, second = optim_ckpts[0], optim_ckpts[1]
    assert 'param_groups' in first and 'param_groups' not in second
    expected = optimizer.state_dict()
    assert set(expected['state'].keys()) == set(first['state'].keys()) | set(second['state'].keys())
    assert len(set(first['state'].keys()) & set(second['state'].keys())) == 0
    for os_key, expected_state in expected['state'].items():
        holder = first if os_key in first['state'] else second
        saved_state = holder['state'][os_key]
        for v1, v2 in zip(expected_state.values(), saved_state.values()):
            if not isinstance(v2, torch.Tensor):
                assert v1 == v2
                continue
            assert torch.equal(v1, v2)
    assert expected['param_groups'] == first['param_groups']
def test_dist_model_optimizer():
    """With a ``dist_meta`` and default ``eliminate_replica``, one model and one optimizer checkpoint result."""
    model = DummyModel()
    for param in model.parameters():
        param.grad = torch.rand_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()

    # First case: the two parameters carry different first ParamDistMeta arguments.
    dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_ckpts, optim_ckpts, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_ckpts) == 1
    assert len(optim_ckpts) == 1
    assert 'fc.weight' in model_ckpts[0] and 'fc.bias' in model_ckpts[0]
    assert 0 in optim_ckpts[0]['state'] and 1 in optim_ckpts[0]['state']

    # Second case: both parameters use the same meta values.
    dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_ckpts, optim_ckpts, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_ckpts) == 1
    assert len(optim_ckpts) == 1
if __name__ == '__main__':
    # Run every test of this file in definition order when executed as a script.
    for _test in (test_global_model, test_global_model_shard, test_global_optimizer, test_global_optimizer_shard,
                  test_dist_model_optimizer):
        _test()

View File

@ -0,0 +1,188 @@
from copy import deepcopy
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.io import load, save
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta)
from torch import Tensor
from torch.nn import Module
from torch.optim import Adam, Optimizer
def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    """Assert that both state dicts have the same keys and element-wise equal tensors."""
    assert set(a.keys()) == set(b.keys())
    for name in a:
        assert torch.equal(a[name], b[name])
def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None:
    """Assert that two optimizer state dicts match.

    State keys must be equal and the values of each state are compared pairwise
    in iteration order (tensors with ``torch.equal``, everything else with ``==``).
    ``'param_groups'`` are compared unless ``ignore_param_gruops`` is set.
    """
    states_a, states_b = a['state'], b['state']
    assert set(states_a.keys()) == set(states_b.keys())
    for idx, entry_a in states_a.items():
        entry_b = states_b[idx]
        for lhs, rhs in zip(entry_a.values(), entry_b.values()):
            if not isinstance(lhs, Tensor):
                assert lhs == rhs
                continue
            assert torch.equal(lhs, rhs)
    if ignore_param_gruops:
        return
    assert a['param_groups'] == b['param_groups']
class DummyModel(nn.Module):
    """Tiny fixture module with a single 20 -> 1 linear layer named ``fc``."""

    def __init__(self) -> None:
        super().__init__()
        in_features, out_features = 20, 1
        self.fc = nn.Linear(in_features, out_features)
def prepare_model_optim(shard: bool = False, zero: bool = False):
    """Build a ``DummyModel`` and an Adam optimizer that has taken one step on random gradients.

    Args:
        shard: Replace ``fc.weight`` data with one of its two chunks along dim 1,
            selected by ``rank % 2``.
        zero: Additionally flatten the weight and keep the ``[3, rest]`` split piece
            selected by ``rank // 2``; for a non-zero ``rank // 2`` the bias becomes empty.

    Returns:
        ``(model, optimizer)``.
    """
    model = DummyModel()
    weight, bias = model.fc.weight, model.fc.bias
    if shard:
        weight.data = weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        pieces = weight.reshape(-1).split([3, weight.size(1) - 3], 0)
        weight.data = pieces[dp_rank]
        if dp_rank != 0:
            bias.data = torch.empty(0, dtype=bias.dtype)
    for param in model.parameters():
        param.grad = torch.rand_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer
def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0):
    """Fill every model parameter and every tensor in the optimizer state with ``scalar`` in place."""
    with torch.no_grad():
        for param in model.parameters():
            param.fill_(scalar)
        state_tensors = (value for per_param in optimizer.state.values() for value in per_param.values()
                         if isinstance(value, Tensor))
        for tensor in state_tensors:
            tensor.fill_(scalar)
def get_dist_metas(nprocs: int, zero: bool = False):
    """Return one ``{'fc.weight': ..., 'fc.bias': ...}`` ParamDistMeta dict per rank.

    Ranks are laid out as ``rank // 2`` and ``rank % 2`` with ``nprocs // 2`` as the
    data-parallel world size; the weight is split in 2 parts along dim 1. With
    ``zero`` the metas additionally carry ``zero_numel`` / ``zero_orig_shape``.
    """
    dp_world_size = nprocs // 2
    dist_metas = []
    for rank in range(nprocs):
        dp_rank, tp_rank = divmod(rank, 2)
        if zero:
            weight_meta = ParamDistMeta(dp_rank,
                                        dp_world_size,
                                        tp_rank,
                                        2,
                                        tp_shard_dims=[1],
                                        tp_num_parts=[2],
                                        zero_numel=10,
                                        zero_orig_shape=[1, 10])
            bias_meta = ParamDistMeta(dp_rank, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
        else:
            weight_meta = ParamDistMeta(dp_rank, dp_world_size, tp_rank, 2, tp_shard_dims=[1], tp_num_parts=[2])
            bias_meta = ParamDistMeta(dp_rank, dp_world_size, 0, 1)
        dist_metas.append({'fc.weight': weight_meta, 'fc.bias': bias_meta})
    return dist_metas
def get_redist_meta(nprocs: int):
    """Build the ``RedistMeta`` describing how ``fc.weight`` / ``fc.bias`` map onto ``nprocs`` ranks."""
    dp_world_size = nprocs // 2
    weight_ranks, bias_ranks = {}, {}
    for rank in range(nprocs):
        weight_ranks[rank] = RankRedistMeta(rank // 2, rank % 2, 0)
        bias_ranks[rank] = RankRedistMeta(rank // 2, 0, 0)
    rank_meta = {'fc.weight': weight_ranks, 'fc.bias': bias_ranks}
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_world_size, 1),
    }
    return RedistMeta(rank_meta, [], param_meta)
@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0])
def test_save_global_load_global(max_shard_size_gb: float):
    """Save a global checkpoint, load it into a fresh model/optimizer and compare state dicts."""
    src_model, src_optimizer = prepare_model_optim()
    with TemporaryDirectory() as ckpt_dir:
        save(ckpt_dir, src_model, src_optimizer, max_shard_size_gb=max_shard_size_gb)
        dst_model, dst_optimizer = prepare_model_optim()
        load(ckpt_dir, dst_model, dst_optimizer, max_shard_size_gb=max_shard_size_gb)
        check_model_state_dict(src_model.state_dict(), dst_model.state_dict())
        check_optim_state_dict(src_optimizer.state_dict(), dst_optimizer.state_dict())
def run_dist(rank, world_size, port, func):
    """Launch ColossalAI on this process with an empty config, then call ``func``."""
    launch_kwargs = dict(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    colossalai.launch(**launch_kwargs)
    func()
def launch_dist(fn, world_size: int):
    """Spawn ``world_size`` processes that each run ``run_dist`` with ``fn`` on a free port."""
    worker = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
    mp.spawn(worker, nprocs=world_size)
def save_dist(dir_name: str, zero: bool):
    """Per-process body: build a sharded model/optimizer, zero it and save it with this rank's dist meta."""
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    reset_model_optim(model, optimizer)
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    rank_dist_meta = get_dist_metas(world_size, zero)[rank]
    save(dir_name, model, optimizer, dist_meta=rank_dist_meta)
def load_and_check_dist(dir_name: str):
    """Per-process body: load a checkpoint into a sharded model/optimizer and verify the result.

    The expected state is a snapshot taken after resetting everything to 0; the
    live model/optimizer are then filled with 1 so that a successful load must
    overwrite them.
    """
    world_size = dist.get_world_size()
    model, optimizer = prepare_model_optim(shard=True)
    reset_model_optim(model, optimizer)
    expected_model_state = deepcopy(model.state_dict())
    expected_optim_state = deepcopy(optimizer.state_dict())
    reset_model_optim(model, optimizer, 1)
    load(dir_name, model, optimizer, get_redist_meta(world_size), get_dist_metas(world_size))
    check_model_state_dict(expected_model_state, model.state_dict())
    check_optim_state_dict(expected_optim_state, optimizer.state_dict())
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_global_load_dist():
    """Save a zeroed global checkpoint in this process, then load and check it on 4 spawned processes."""
    model, optimizer = prepare_model_optim()
    reset_model_optim(model, optimizer)
    with TemporaryDirectory() as ckpt_dir:
        save(ckpt_dir, model, optimizer)
        launch_dist(partial(load_and_check_dist, ckpt_dir), 4)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist_load_dist():
    """Save distributed checkpoints and load them back under several world sizes."""
    with TemporaryDirectory() as ckpt_dir:
        # save tp + dp
        launch_dist(partial(save_dist, ckpt_dir, False), 2)
        # load tp + dp
        launch_dist(partial(load_and_check_dist, ckpt_dir), 2)
    with TemporaryDirectory() as ckpt_dir:
        # save tp + zero
        launch_dist(partial(save_dist, ckpt_dir, True), 4)
        # load tp + dp
        loader = partial(load_and_check_dist, ckpt_dir)
        for world_size in (2, 4):
            launch_dist(loader, world_size)
if __name__ == '__main__':
    # Mirror the pytest parametrization when run as a script.
    for _size_gb in (80 / 1024**3, 0):
        test_save_global_load_global(_size_gb)
    test_save_global_load_dist()
    test_save_dist_load_dist()

View File

@ -0,0 +1,127 @@
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import save, merge
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tempfile import TemporaryDirectory
from torch.optim import Adam
from functools import partial
import torch
import os
import pytest
import colossalai
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
class DummyModel(nn.Module):
    """Tiny fixture module with a single 20 -> 1 linear layer named ``fc``."""

    def __init__(self) -> None:
        super().__init__()
        in_features, out_features = 20, 1
        self.fc = nn.Linear(in_features, out_features)
def prepare_model_optim(shard: bool = False, zero: bool = False):
    """Build a ``DummyModel`` and an Adam optimizer that has taken one step on all-ones gradients.

    Args:
        shard: Replace ``fc.weight`` data with one of its two chunks along dim 1,
            selected by ``rank % 2``.
        zero: Additionally flatten the weight and keep the ``[3, rest]`` split piece
            selected by ``rank // 2``; for a non-zero ``rank // 2`` the bias becomes empty.

    Returns:
        ``(model, optimizer)``.
    """
    model = DummyModel()
    weight, bias = model.fc.weight, model.fc.bias
    if shard:
        weight.data = weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        pieces = weight.reshape(-1).split([3, weight.size(1) - 3], 0)
        weight.data = pieces[dp_rank]
        if dp_rank != 0:
            bias.data = torch.empty(0, dtype=bias.dtype)
    for param in model.parameters():
        param.grad = torch.ones_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer
def test_merge_global():
    """Merging a globally saved checkpoint (unsharded or sharded) leaves the output directory empty."""
    model, optimizer = prepare_model_optim()
    # First pass saves unsharded, second pass with an 80-byte shard limit.
    for save_kwargs in ({}, {'max_shard_size_gb': 80 / 1024**3}):
        with TemporaryDirectory() as ckpt_dir:
            save(ckpt_dir, model, optimizer, **save_kwargs)
            with TemporaryDirectory() as merged_dir:
                merge(ckpt_dir, merged_dir)
                assert len(os.listdir(merged_dir)) == 0
def run_dist(rank, world_size, port, func):
    """Launch ColossalAI on this process with a 1d tensor-parallel config of size 2, then call ``func``."""
    parallel_config = {'parallel': {'tensor': {'mode': '1d', 'size': 2}}}
    colossalai.launch(config=parallel_config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    func()
def run_save_dist(dir_name: str, zero: bool):
    """Per-process body: build a sharded model/optimizer and save it with this rank's dist meta."""
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    dp_world_size = dist.get_world_size() // 2
    dp_rank, tp_rank = rank // 2, rank % 2
    if zero:
        weight_meta = ParamDistMeta(dp_rank,
                                    dp_world_size,
                                    tp_rank,
                                    2,
                                    tp_shard_dims=[1],
                                    tp_num_parts=[2],
                                    zero_numel=10,
                                    zero_orig_shape=[1, 10])
        bias_meta = ParamDistMeta(dp_rank, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
    else:
        weight_meta = ParamDistMeta(dp_rank, dp_world_size, tp_rank, 2, tp_shard_dims=[1], tp_num_parts=[2])
        bias_meta = ParamDistMeta(dp_rank, dp_world_size, 0, 1)
    dist_metas = {'fc.weight': weight_meta, 'fc.bias': bias_meta}
    save(dir_name, model, optimizer, dist_meta=dist_metas)
@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_merge_tp_dp(zero: bool):
    """Save on 4 processes, merge, and check the merged output is a single global checkpoint."""
    with TemporaryDirectory() as ckpt_dir:
        world_size = 4
        saver = partial(run_save_dist, ckpt_dir, zero)
        worker = partial(run_dist, world_size=world_size, port=free_port(), func=saver)
        mp.spawn(worker, nprocs=world_size)
        with TemporaryDirectory() as merged_dir:
            merge(ckpt_dir, merged_dir)

            def load_merged(file_name: str):
                # Load one file of the merged checkpoint by name.
                return torch.load(os.path.join(merged_dir, file_name))

            assert len(os.listdir(merged_dir)) == 5
            global_meta = load_merged(GLOBAL_META_FILE_NAME)
            assert len(global_meta['meta']) == 1
            meta = load_merged(global_meta['meta'][0])
            assert meta['dist_meta'] is None
            assert len(meta['params']) == 2
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
            model_state_dict = load_merged(meta['model'][0])
            assert len(model_state_dict) == 2
            # The two halves along dim 1 must be merged back to the full width.
            assert model_state_dict['fc.weight'].size(1) == 20
            optimizer_state_dict = load_merged(meta['optimizer'][0])
            assert len(optimizer_state_dict['state']) == 2
            assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
            weight_state = optimizer_state_dict['state'][0]
            assert weight_state['exp_avg'].size(1) == 20
            assert weight_state['exp_avg_sq'].size(1) == 20
if __name__ == '__main__':
    # Mirror the pytest parametrization when run as a script.
    test_merge_global()
    for _use_zero in (False, True):
        test_merge_tp_dp(_use_zero)

View File

@ -0,0 +1,101 @@
import torch
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param
def test_unflatten_zero_param_even() -> None:
    """A 4x4 tensor flattened into 4 equal chunks is restored by ``unflatten_zero_param`` and ``merge_param``."""
    metas = [ParamDistMeta(rank, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for rank in range(4)]
    expected = torch.rand(4, 4)
    flat_chunks = list(expected.reshape(-1).chunk(4))
    assert torch.equal(expected, unflatten_zero_param(flat_chunks, metas))
    assert torch.equal(expected, merge_param(flat_chunks, metas))
def test_unflatten_zero_param_uneven() -> None:
    """A 4x4 tensor flattened into uneven pieces of 13 and 3 elements is restored by both helpers."""
    metas = [ParamDistMeta(rank, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for rank in range(1, 3)]
    expected = torch.rand(4, 4)
    flat_pieces = list(expected.reshape(-1).split([13, 3]))
    assert torch.equal(expected, unflatten_zero_param(flat_pieces, metas))
    assert torch.equal(expected, merge_param(flat_pieces, metas))
def test_gather_tp_param_1d_row() -> None:
    """A 4x4 tensor chunked 4 ways along dim 0 is restored by ``gather_tp_param`` and ``merge_param``."""
    metas = [ParamDistMeta(0, 1, rank, 4, tp_shard_dims=[0], tp_num_parts=[4]) for rank in range(4)]
    expected = torch.rand(4, 4)
    parts = [part.contiguous() for part in expected.chunk(4, 0)]
    assert torch.equal(expected, gather_tp_param(parts, metas))
    assert torch.equal(expected, merge_param(parts, metas))
def test_gather_tp_param_1d_col() -> None:
    """A 4x4 tensor chunked 4 ways along dim 1 is restored by ``gather_tp_param`` and ``merge_param``."""
    metas = [ParamDistMeta(0, 1, rank, 4, tp_shard_dims=[1], tp_num_parts=[4]) for rank in range(4)]
    expected = torch.rand(4, 4)
    parts = [part.contiguous() for part in expected.chunk(4, 1)]
    assert torch.equal(expected, gather_tp_param(parts, metas))
    assert torch.equal(expected, merge_param(parts, metas))
def test_gather_tp_param_2d() -> None:
    """A 4x6 tensor split 2 ways along dim 0 and 3 ways along dim 1 is restored by both helpers."""
    metas = [ParamDistMeta(0, 1, rank, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for rank in range(6)]
    expected = torch.rand(4, 6)
    parts = []
    for row_block in expected.chunk(2, 0):
        parts.extend(block.contiguous() for block in row_block.chunk(3, 1))
    assert torch.equal(expected, gather_tp_param(parts, metas))
    assert torch.equal(expected, merge_param(parts, metas))
def test_gather_tp_param_2d_reverse() -> None:
    """Same 2x3 split as the 2d test, but with shard dims listed as ``[1, 0]`` and parts as ``[3, 2]``."""
    metas = [ParamDistMeta(0, 1, rank, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for rank in range(6)]
    expected = torch.rand(4, 6)
    parts = []
    for row_block in expected.chunk(2, 0):
        parts.extend(block.contiguous() for block in row_block.chunk(3, 1))
    assert torch.equal(expected, gather_tp_param(parts, metas))
    assert torch.equal(expected, merge_param(parts, metas))
def test_merge_param_hybrid() -> None:
    """``merge_param`` restores a 4x6 tensor that was split into 2x3 blocks, each flattened and split ``[1, 3]``."""
    metas = []
    for idx in range(12):
        metas.append(
            ParamDistMeta(idx % 2,
                          2,
                          idx // 2,
                          6,
                          tp_shard_dims=[1, 0],
                          tp_num_parts=[3, 2],
                          zero_numel=4,
                          zero_orig_shape=[2, 2]))
    expected = torch.rand(4, 6)
    pieces = []
    for row_block in expected.chunk(2, 0):
        for block in row_block.chunk(3, 1):
            # Each 2x2 block is flattened and split into pieces of 1 and 3 elements.
            pieces.extend(block.contiguous().reshape(-1).split([1, 3]))
    merged = merge_param(pieces, metas)
    assert torch.equal(expected, merged)
def test_merge_param_dummy() -> None:
    """``merge_param`` on a single tensor with a ``ParamDistMeta(0, 1, 0, 1)`` returns an equal tensor."""
    expected = torch.rand(4, 6)
    metas = [ParamDistMeta(0, 1, 0, 1)]
    assert torch.equal(expected, merge_param([expected], metas))
if __name__ == '__main__':
    # Run every test of this file in definition order when executed as a script.
    for _test in (
            test_unflatten_zero_param_even,
            test_unflatten_zero_param_uneven,
            test_gather_tp_param_1d_row,
            test_gather_tp_param_1d_col,
            test_gather_tp_param_2d,
            test_gather_tp_param_2d_reverse,
            test_merge_param_hybrid,
            test_merge_param_dummy,
    ):
        _test()

View File

@ -0,0 +1,149 @@
import os
from functools import partial
from tempfile import TemporaryDirectory
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import redist, save
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta,
RedistMeta)
from torch.optim import Adam
class DummyModel(nn.Module):
    """Tiny fixture module with a single 20 -> 1 linear layer named ``fc``."""

    def __init__(self) -> None:
        super().__init__()
        in_features, out_features = 20, 1
        self.fc = nn.Linear(in_features, out_features)
def prepare_model_optim(shard: bool = False, zero: bool = False):
    """Build a ``DummyModel`` and an Adam optimizer that has taken one step on all-ones gradients.

    Args:
        shard: Replace ``fc.weight`` data with one of its two chunks along dim 1,
            selected by ``rank % 2``.
        zero: Additionally flatten the weight and keep the ``[3, rest]`` split piece
            selected by ``rank // 2``; for a non-zero ``rank // 2`` the bias becomes empty.

    Returns:
        ``(model, optimizer)``.
    """
    model = DummyModel()
    weight, bias = model.fc.weight, model.fc.bias
    if shard:
        weight.data = weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        pieces = weight.reshape(-1).split([3, weight.size(1) - 3], 0)
        weight.data = pieces[dp_rank]
        if dp_rank != 0:
            bias.data = torch.empty(0, dtype=bias.dtype)
    for param in model.parameters():
        param.grad = torch.ones_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer
def get_dist_metas(nprocs: int, zero: bool = False):
    """Return one ``{'fc.weight': ..., 'fc.bias': ...}`` ParamDistMeta dict per rank.

    Ranks are laid out as ``rank // 2`` and ``rank % 2`` with ``nprocs // 2`` as the
    data-parallel world size; the weight is split in 2 parts along dim 1. With
    ``zero`` the metas additionally carry ``zero_numel`` / ``zero_orig_shape``.
    """
    dp_world_size = nprocs // 2
    dist_metas = []
    for rank in range(nprocs):
        dp_rank, tp_rank = divmod(rank, 2)
        if zero:
            weight_meta = ParamDistMeta(dp_rank,
                                        dp_world_size,
                                        tp_rank,
                                        2,
                                        tp_shard_dims=[1],
                                        tp_num_parts=[2],
                                        zero_numel=10,
                                        zero_orig_shape=[1, 10])
            bias_meta = ParamDistMeta(dp_rank, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
        else:
            weight_meta = ParamDistMeta(dp_rank, dp_world_size, tp_rank, 2, tp_shard_dims=[1], tp_num_parts=[2])
            bias_meta = ParamDistMeta(dp_rank, dp_world_size, 0, 1)
        dist_metas.append({'fc.weight': weight_meta, 'fc.bias': bias_meta})
    return dist_metas
def get_redist_meta(nprocs: int):
    """Build the ``RedistMeta`` describing how ``fc.weight`` / ``fc.bias`` map onto ``nprocs`` ranks."""
    dp_world_size = nprocs // 2
    weight_ranks, bias_ranks = {}, {}
    for rank in range(nprocs):
        weight_ranks[rank] = RankRedistMeta(rank // 2, rank % 2, 0)
        bias_ranks[rank] = RankRedistMeta(rank // 2, 0, 0)
    rank_meta = {'fc.weight': weight_ranks, 'fc.bias': bias_ranks}
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_world_size, 1),
    }
    return RedistMeta(rank_meta, [], param_meta)
def check_checkpoint_shape(dir_name: str):
    """Assert that every per-rank checkpoint listed in the global meta holds half-width (10 column) tensors."""

    def load_file(file_name: str):
        # Load one checkpoint file of ``dir_name`` by name.
        return torch.load(os.path.join(dir_name, file_name))

    global_meta = load_file(GLOBAL_META_FILE_NAME)
    for meta_name in global_meta['meta']:
        meta = load_file(meta_name)
        assert meta['dist_meta'] is not None
        assert len(meta['params']) == 2
        assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
        model_state_dict = load_file(meta['model'][0])
        assert len(model_state_dict) == 2
        assert model_state_dict['fc.weight'].size(1) == 10
        optimizer_state_dict = load_file(meta['optimizer'][0])
        assert len(optimizer_state_dict['state']) == 2
        assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
        weight_state = optimizer_state_dict['state'][0]
        assert weight_state['exp_avg'].size(1) == 10
        assert weight_state['exp_avg_sq'].size(1) == 10
def test_global_to_dist():
    """Redistribute a globally saved checkpoint to a 4-rank layout and check the resulting shapes."""
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as ckpt_dir:
        save(ckpt_dir, model, optimizer)
        with TemporaryDirectory() as redist_dir:
            redist(ckpt_dir, redist_dir, get_redist_meta(4), get_dist_metas(4))
            check_checkpoint_shape(redist_dir)
def run_dist(rank, world_size, port, func):
    """Launch ColossalAI on this process with a 1d tensor-parallel config of size 2, then call ``func``."""
    parallel_config = {'parallel': {'tensor': {'mode': '1d', 'size': 2}}}
    colossalai.launch(config=parallel_config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    func()
def run_save_dist(dir_name: str, zero: bool):
    """Per-process body: build a sharded model/optimizer and save it with this rank's 4-rank dist meta."""
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    rank_dist_meta = get_dist_metas(4, zero)[rank]
    save(dir_name, model, optimizer, dist_meta=rank_dist_meta)
@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_dist_to_dist(zero: bool):
    """Save on 4 processes, then redistribute to the 4-rank non-zero layout.

    Without ``zero`` the output directory is expected to stay empty; with ``zero``
    the redistributed checkpoints are shape-checked.
    """
    with TemporaryDirectory() as ckpt_dir:
        world_size = 4
        saver = partial(run_save_dist, ckpt_dir, zero)
        worker = partial(run_dist, world_size=world_size, port=free_port(), func=saver)
        mp.spawn(worker, nprocs=world_size)
        with TemporaryDirectory() as redist_dir:
            redist(ckpt_dir, redist_dir, get_redist_meta(4), get_dist_metas(4))
            if zero:
                check_checkpoint_shape(redist_dir)
            else:
                assert len(os.listdir(redist_dir)) == 0
if __name__ == '__main__':
    # Mirror the pytest parametrization when run as a script.
    test_global_to_dist()
    for _use_zero in (False, True):
        test_dist_to_dist(_use_zero)

View File

@ -0,0 +1,147 @@
import os
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME,
OTHER_CKPT_FILE_NAME)
from colossalai.utils.checkpoint_io.io import save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from torch import Tensor
from torch.optim import Adam
def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    """Assert that both state dicts have the same keys and element-wise equal tensors."""
    assert set(a.keys()) == set(b.keys())
    for name in a:
        assert torch.equal(a[name], b[name])
def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None:
    """Assert that two optimizer state dicts match.

    State keys must be equal and the values of each state are compared pairwise
    in iteration order (tensors with ``torch.equal``, everything else with ``==``).
    ``'param_groups'`` are compared unless ``ignore_param_gruops`` is set.
    """
    states_a, states_b = a['state'], b['state']
    assert set(states_a.keys()) == set(states_b.keys())
    for idx, entry_a in states_a.items():
        entry_b = states_b[idx]
        for lhs, rhs in zip(entry_a.values(), entry_b.values()):
            if not isinstance(lhs, Tensor):
                assert lhs == rhs
                continue
            assert torch.equal(lhs, rhs)
    if ignore_param_gruops:
        return
    assert a['param_groups'] == b['param_groups']
class DummyModel(nn.Module):
    """Tiny fixture module with a single 20 -> 1 linear layer named ``fc``."""

    def __init__(self) -> None:
        super().__init__()
        in_features, out_features = 20, 1
        self.fc = nn.Linear(in_features, out_features)
def prepare_model_optim():
    """Build a ``DummyModel`` and an Adam optimizer that has taken one step on all-ones gradients."""
    model = DummyModel()
    for param in model.parameters():
        param.grad = torch.ones_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer
def test_overwrite():
    """Saving into a directory that already holds the first model shard file must raise."""
    model = DummyModel()
    with TemporaryDirectory() as ckpt_dir:
        existing = os.path.join(ckpt_dir, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin'))
        # Create an empty file where the first model shard would be written.
        with open(existing, 'a'):
            pass
        with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'):
            save(ckpt_dir, model)
def test_save_global():
    """A global save writes 5 files whose contents match the source model and optimizer."""
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as ckpt_dir:
        save(ckpt_dir, model, optimizer)

        def load_file(file_name: str):
            # Load one checkpoint file of ``ckpt_dir`` by name.
            return torch.load(os.path.join(ckpt_dir, file_name))

        assert len(os.listdir(ckpt_dir)) == 5
        global_meta = load_file(GLOBAL_META_FILE_NAME)
        assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME
        meta = load_file(META_CKPT_FILE_NAME)
        assert len(meta['model']) == 1
        assert len(meta['optimizer']) == 1
        saved_model_state = load_file(meta['model'][0])
        check_model_state_dict(model.state_dict(), saved_model_state)
        saved_optim_state = load_file(meta['optimizer'][0])
        check_optim_state_dict(optimizer.state_dict(), saved_optim_state)
        others = load_file(OTHER_CKPT_FILE_NAME)
        assert len(others) == 0
def test_save_global_shard():
    """Save with a tiny ``max_shard_size_gb`` and check that two shards are produced.

    Expects seven directory entries, two model shards and two optimizer shards
    with disjoint keys, ``'param_groups'`` only in the first optimizer shard, and
    the merged shards equal to the live state dicts.
    """
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        # 80 / 1024**3 is a very small limit, chosen so the checkpoint is split.
        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
        assert len(os.listdir(dir_name)) == 7

        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
        assert len(meta['model']) == 2 and len(meta['optimizer']) == 2

        # Model shards: no shared keys, and their union equals the full state dict.
        model_shards = [torch.load(os.path.join(dir_name, shard)) for shard in meta['model']]
        assert set(model_shards[0].keys()).isdisjoint(set(model_shards[1].keys()))
        merged_model_state = dict(model_shards[0])
        merged_model_state.update(model_shards[1])
        check_model_state_dict(model.state_dict(), merged_model_state)

        # Optimizer shards: disjoint 'state' keys, 'param_groups' only in the first shard.
        optim_shards = [torch.load(os.path.join(dir_name, shard)) for shard in meta['optimizer']]
        first_shard, second_shard = optim_shards[0], optim_shards[1]
        assert set(first_shard['state'].keys()).isdisjoint(set(second_shard['state'].keys()))
        assert 'param_groups' in first_shard and 'param_groups' not in second_shard

        merged_optim_state = dict(first_shard['state'])
        merged_optim_state.update(second_shard['state'])
        check_optim_state_dict(optimizer.state_dict(), {
            'state': merged_optim_state,
            'param_groups': first_shard['param_groups'],
        })
def run_dist(rank, world_size, port, func):
    """Launch colossalai for this process on localhost with the NCCL backend, then call ``func``.

    Args:
        rank: Rank passed to ``colossalai.launch``.
        world_size: World size passed to ``colossalai.launch``.
        port: Port passed to ``colossalai.launch``.
        func: Zero-argument callable executed after launch.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    func()
def run_save_dist(dir_name):
    """Save a freshly prepared model and optimizer with per-parameter dist meta.

    Both ``fc.weight`` and ``fc.bias`` get a ``ParamDistMeta`` built from the
    current process rank and world size, followed by the constants ``0`` and ``1``.

    Args:
        dir_name: Directory passed to ``save``.
    """
    model, optimizer = prepare_model_optim()
    dist_metas = {
        param_name: ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1)
        for param_name in ('fc.weight', 'fc.bias')
    }
    save(dir_name, model, optimizer, dist_meta=dist_metas)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist():
    """Run ``run_save_dist`` in two spawned processes and inspect the written files.

    Expects eight directory entries and two meta files in the global meta. Each
    meta file must carry ``'dist_meta'`` and reference one model file with two
    entries and one optimizer file with two state entries plus ``'param_groups'``.
    """
    world_size = 2
    with TemporaryDirectory() as dir_name:
        worker = partial(run_save_dist, dir_name)
        launcher = partial(run_dist, world_size=world_size, port=free_port(), func=worker)
        # mp.spawn supplies the process index as the first positional argument (rank).
        mp.spawn(launcher, nprocs=world_size)

        assert len(os.listdir(dir_name)) == 8
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 2

        for meta_name in global_meta['meta']:
            meta = torch.load(os.path.join(dir_name, meta_name))
            assert meta.get('dist_meta') is not None
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1

            saved_model_state = torch.load(os.path.join(dir_name, meta['model'][0]))
            assert len(saved_model_state) == 2

            saved_optim_state = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
            assert len(saved_optim_state['state']) == 2
            assert 'param_groups' in saved_optim_state
if __name__ == '__main__':
    # Allow running this file directly; the tests execute in definition order.
    for test_fn in (test_overwrite, test_save_global, test_save_global_shard, test_save_dist):
        test_fn()

View File

@ -0,0 +1,137 @@
import torch
from colossalai.utils.checkpoint_io.meta import ParamRedistMeta
from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param
def test_flatten_zero_param_even() -> None:
    """A 4x4 tensor with offsets [0, 4, 8, 12] is flattened into four equal chunks.

    ``unmerge_param`` is expected to return a single outer entry holding the same
    four chunks.
    """
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12])
    orig_tensor = torch.rand(4, 4)
    expected_chunks = list(orig_tensor.reshape(-1).chunk(4))

    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    assert len(expected_chunks) == len(flat_tensors)
    for expected, actual in zip(expected_chunks, flat_tensors):
        assert torch.equal(expected, actual)

    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    inner_tensors = unmerged_tensors[0]
    assert len(expected_chunks) == len(inner_tensors)
    for expected, actual in zip(expected_chunks, inner_tensors):
        assert torch.equal(expected, actual)
def test_flatten_zero_param_uneven() -> None:
    """A 4x4 tensor with start rank 1 and offsets [0, 13] yields empty first and last pieces.

    The pieces in between must equal the flattened tensor split into sizes 13 and 3,
    both for ``flatten_zero_param`` and for the single outer entry of ``unmerge_param``.
    """
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13])
    orig_tensor = torch.rand(4, 4)
    expected_pieces = list(orig_tensor.reshape(-1).split([13, 3]))

    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    # The outermost pieces are expected to be empty.
    assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0
    inner_flat = flat_tensors[1:-1]
    assert len(expected_pieces) == len(inner_flat)
    for expected, actual in zip(expected_pieces, inner_flat):
        assert torch.equal(expected, actual)

    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    pieces = unmerged_tensors[0]
    assert pieces[0].size(0) == 0 and pieces[-1].size(0) == 0
    inner_pieces = pieces[1:-1]
    assert len(expected_pieces) == len(inner_pieces)
    for expected, actual in zip(expected_pieces, inner_pieces):
        assert torch.equal(expected, actual)
def test_split_tp_param_1d_row() -> None:
    """Splitting a 4x4 tensor into 4 parts along dim 0 must match ``chunk(4, 0)``.

    ``unmerge_param`` is expected to return one single-element list per part.
    """
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    expected_parts = []
    for row_block in orig_tensor.chunk(4, 0):
        expected_parts.append(row_block.contiguous())

    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(expected_parts) == len(split_tensors)
    for expected, actual in zip(expected_parts, split_tensors):
        assert torch.equal(expected, actual)

    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(expected_parts) == len(unmerged_tensors)
    for expected, inner in zip(expected_parts, unmerged_tensors):
        assert len(inner) == 1
        assert torch.equal(expected, inner[0])
def test_split_tp_param_1d_col() -> None:
    """Splitting a 4x4 tensor into 4 parts along dim 1 must match ``chunk(4, 1)``.

    ``unmerge_param`` is expected to return one single-element list per part.
    """
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    expected_parts = []
    for col_block in orig_tensor.chunk(4, 1):
        expected_parts.append(col_block.contiguous())

    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(expected_parts) == len(split_tensors)
    for expected, actual in zip(expected_parts, split_tensors):
        assert torch.equal(expected, actual)

    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(expected_parts) == len(unmerged_tensors)
    for expected, inner in zip(expected_parts, unmerged_tensors):
        assert len(inner) == 1
        assert torch.equal(expected, inner[0])
def test_split_tp_param_2d() -> None:
    """Splitting a 4x6 tensor with dims [0, 1] and parts [2, 3] must yield a 2x3 grid of blocks.

    The expected order is row block first, then column block within it.
    ``unmerge_param`` is expected to return one single-element list per block.
    """
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3])
    orig_tensor = torch.rand(4, 6)
    expected_blocks = []
    for row_block in orig_tensor.chunk(2, 0):
        for block in row_block.chunk(3, 1):
            expected_blocks.append(block.contiguous())

    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(expected_blocks) == len(split_tensors)
    for expected, actual in zip(expected_blocks, split_tensors):
        assert torch.equal(expected, actual)

    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(expected_blocks) == len(unmerged_tensors)
    for expected, inner in zip(expected_blocks, unmerged_tensors):
        assert len(inner) == 1
        assert torch.equal(expected, inner[0])
def test_split_tp_param_2d_reverse() -> None:
    """Listing the shard dims as [1, 0] with parts [3, 2] must give the same blocks as the 2D case.

    The expected blocks are still ordered row block first, then column block.
    ``unmerge_param`` is expected to return one single-element list per block.
    """
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2])
    orig_tensor = torch.rand(4, 6)
    expected_blocks = []
    for row_block in orig_tensor.chunk(2, 0):
        for block in row_block.chunk(3, 1):
            expected_blocks.append(block.contiguous())

    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(expected_blocks) == len(split_tensors)
    for expected, actual in zip(expected_blocks, split_tensors):
        assert torch.equal(expected, actual)

    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(expected_blocks) == len(unmerged_tensors)
    for expected, inner in zip(expected_blocks, unmerged_tensors):
        assert len(inner) == 1
        assert torch.equal(expected, inner[0])
def test_unmerge_param_hybrid() -> None:
    """``unmerge_param`` with both tp sharding and zero offsets on a 4x6 tensor.

    Each of the six 2x2 blocks is flattened and split into pieces of size 1 and 3,
    giving twelve expected tensors. The result must be six outer entries, the first
    holding two pieces, and must match the expected tensors in block-major order.
    """
    redist_meta = ParamRedistMeta(
        2,
        6,
        tp_shard_dims=[1, 0],
        tp_num_parts=[3, 2],
        zero_start_dp_rank=0,
        zero_offsets=[0, 1],
    )
    orig_tensor = torch.rand(4, 6)

    expected_pieces = []
    for row_block in orig_tensor.chunk(2, 0):
        for block in row_block.chunk(3, 1):
            expected_pieces.extend(block.contiguous().reshape(-1).split([1, 3]))

    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2
    # Two pieces per block, so flat index i maps to (block, piece) = divmod(i, 2).
    for flat_idx, expected in enumerate(expected_pieces):
        tp_rank, dp_rank = divmod(flat_idx, 2)
        assert torch.equal(expected, unmerged_tensors[tp_rank][dp_rank])
def test_unmerge_param_dummy() -> None:
    """With ``ParamRedistMeta(1, 1)``, ``unmerge_param`` must return the tensor unchanged as ``[[t]]``."""
    redist_meta = ParamRedistMeta(1, 1)
    orig_tensor = torch.rand(4, 6)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1
    only_tensor = unmerged_tensors[0][0]
    assert torch.equal(orig_tensor, only_tensor)
if __name__ == '__main__':
    # Allow running this file directly; the tests execute in definition order.
    for test_fn in (
            test_flatten_zero_param_even,
            test_flatten_zero_param_uneven,
            test_split_tp_param_1d_row,
            test_split_tp_param_1d_col,
            test_split_tp_param_2d,
            test_split_tp_param_2d_reverse,
            test_unmerge_param_hybrid,
            test_unmerge_param_dummy,
    ):
        test_fn()