mirror of https://github.com/hpcaitech/ColossalAI
[CheckpointIO] a uniform checkpoint I/O module (#1689)
parent
629172b319
commit
99870726b1
|
@ -0,0 +1,2 @@
|
|||
from .io import load, merge, redist, save
|
||||
from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta)
|
|
@ -0,0 +1,74 @@
|
|||
import shutil
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from .reader import CheckpointReader, DiskCheckpointReader
|
||||
from .writer import CheckpointWriter, DiskCheckpointWriter
|
||||
|
||||
# registry of backend name -> backend class, populated via @register
_backends: Dict[str, Type['CheckpointIOBackend']] = {}


def register(name: str):
    """Return a class decorator that registers a backend class under ``name``.

    Raises:
        AssertionError: if ``name`` is already registered.
    """
    assert name not in _backends, f'"{name}" is registered'

    def _decorator(backend_cls):
        _backends[name] = backend_cls
        return backend_cls

    return _decorator
|
||||
|
||||
|
||||
def get_backend(name: str) -> 'CheckpointIOBackend':
    """Instantiate and return the backend registered under ``name``.

    Raises:
        AssertionError: if no backend was registered under ``name``.
    """
    assert name in _backends, f'Unsupported backend "{name}"'
    backend_cls = _backends[name]
    return backend_cls()
|
||||
|
||||
|
||||
class CheckpointIOBackend(ABC):
    """Abstract factory for checkpoint readers/writers plus temp-dir tracking."""

    def __init__(self) -> None:
        super().__init__()
        # temp locations handed out by get_temp, removed by clean_temp
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Create a writer targeting ``base_name`` for the given rank."""
        ...

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        """Create a reader over the checkpoint at ``base_name``."""
        ...

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        """Create (and track) a temporary location derived from ``base_name``."""
        ...

    @abstractmethod
    def clean_temp(self) -> None:
        """Remove every temporary location created via ``get_temp``."""
        ...
|
||||
|
||||
|
||||
@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):
    """Checkpoint I/O backend backed by the local filesystem."""

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        # create the scratch directory inside the checkpoint directory so it
        # lives on the same filesystem
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        # pop as we go: the original left self.temps populated, so a second
        # clean_temp() call would rmtree already-deleted dirs and raise
        while self.temps:
            shutil.rmtree(self.temps.pop())
|
|
@ -0,0 +1,9 @@
|
|||
import re
|
||||
|
||||
# file names used inside a checkpoint directory
GLOBAL_META_FILE_NAME = 'global_meta.bin'
MODEL_CKPT_FILE_NAME = 'model.bin'
OPTIM_CKPT_FILE_NAME = 'optim.bin'
META_CKPT_FILE_NAME = 'meta.bin'
OTHER_CKPT_FILE_NAME = 'other.bin'

# matches any of the checkpoint file-name stems above
CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')
|
|
@ -0,0 +1,227 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from .distributed import merge_param, unmerge_param
|
||||
from .meta import ParamDistMeta, RedistMeta
|
||||
from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none)
|
||||
|
||||
|
||||
class CheckpointConvertor(ABC):
    """Streaming transformer of checkpoint shards.

    Shards are fed in via ``append`` and the conversion is finalized (and
    validated) by ``complete``.
    """

    @abstractmethod
    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        """Consume one round of per-rank shards."""
        ...

    @abstractmethod
    def complete(self) -> None:
        """Flush any pending output and check that nothing was left behind."""
        ...
|
||||
|
||||
|
||||
class ModelCheckpointConvertor(CheckpointConvertor):
    """Buffers model-state shards per parameter and converts each parameter
    once all of its shards have arrived."""

    def __init__(self, param_count: Dict[str, int]) -> None:
        super().__init__()
        # how many shards each parameter key is split into
        self.param_count = param_count
        # param key -> {rank -> shard tensor}
        self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)

    @abstractmethod
    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        """Convert one fully-collected parameter."""
        ...

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            for key, tensor in state_dict.items():
                self.buffer[key][rank] = tensor
        # flush every parameter whose shards are now all present
        ready_keys = [key for key, rank_dict in self.buffer.items() if len(rank_dict) == self.param_count[key]]
        for key in ready_keys:
            rank_dict = self.buffer.pop(key)
            tensors = list(rank_dict.values())
            # a rank without dist meta contributes no meta (global checkpoint)
            dist_metas = [dist_meta_list[rank][key] for rank in rank_dict.keys() if dist_meta_list[rank] is not None]
            self.convert_tensors(key, tensors, dist_metas)

    def complete(self) -> None:
        # every buffered parameter must have been flushed
        assert len(self.buffer) == 0
|
||||
|
||||
|
||||
class ModelCheckpointMerger(ModelCheckpointConvertor):
    """Merges distributed model shards into a single global checkpoint."""

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None:
        super().__init__(param_count)
        self.sharder = ModelCheckpointSharder(max_shard_size)
        self.save_fn = save_fn

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(tensors)
        merged = merge_param(tensors, dist_metas)
        # the sharder emits a full shard once it overflows; persist it eagerly
        run_if_not_none(self.save_fn, self.sharder.append(key, merged))

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())
|
||||
|
||||
|
||||
class ModelCheckpointRedistor(ModelCheckpointConvertor):
    """Redistributes model shards to a new parallel layout, writing one
    output stream per target process."""

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count)
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in save_fns]
        # param key -> tp rank -> dp rank -> list of target global ranks
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for key, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[key][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        if not dist_metas:
            # source checkpoint is already global: single full tensor
            tensor = tensors[0]
        else:
            assert len(dist_metas) == len(tensors)
            tensor = merge_param(tensors, dist_metas)
        # re-split for the target layout and route each piece to its ranks
        for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])):
            for dp_rank, piece in enumerate(tensor_list):
                for rank in self.rank_map[key][tp_rank][dp_rank]:
                    run_if_not_none(self.save_fns[rank], self.sharders[rank].append(key, piece))

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())
|
||||
|
||||
|
||||
class OptimizerCheckpointConvertor(CheckpointConvertor):
    """Buffers optimizer-state shards per state index and converts each state
    once all of its shards have arrived."""

    def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]],
                 paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__()
        self.param_count = param_count
        self.param_to_os = param_to_os
        self.paired_os = paired_os
        # state index -> {rank -> state dict}
        self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict)
        # reverse mapping (state index -> param key); tolerate a missing
        # mapping instead of crashing on None.items() — param_to_os is
        # Optional and is None for checkpoints without optimizer meta
        self.os_to_param = {} if param_to_os is None else {v: k for k, v in param_to_os.items()}

    @abstractmethod
    def setup(self, param_groups: dict) -> None:
        """Prepare output sharder(s) from the first shard's param_groups."""
        ...

    @abstractmethod
    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        """Convert one fully-collected optimizer state."""
        ...

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            self.setup(state_dict['param_groups'])
            for idx, state in state_dict['state'].items():
                self.buffer[idx][rank] = state
        converted_indices = set()
        for idx, rank_dict in self.buffer.items():
            # a state is complete when as many shards arrived as its param has
            if len(rank_dict) == self.param_count[self.os_to_param[idx]]:
                states = []
                dist_metas = []
                for rank, state in rank_dict.items():
                    states.append(state)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]])
                self.convert_states(idx, states, dist_metas)
                converted_indices.add(idx)
        for idx in converted_indices:
            del self.buffer[idx]

    def complete(self) -> None:
        # every buffered state must have been flushed
        assert len(self.buffer) == 0
|
||||
|
||||
|
||||
class OptimizerCheckpointMerger(OptimizerCheckpointConvertor):
    """Merges distributed optimizer-state shards into a global checkpoint."""

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fn = save_fn
        # created lazily: the sharder needs param_groups from the first shard
        self.sharder = None

    def setup(self, param_groups: dict) -> None:
        if self.sharder is None:
            self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups)

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(states)
        new_state = {}
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                # entry is sharded like its parameter: merge across ranks
                new_state[state_key] = merge_param([s[state_key] for s in states], dist_metas)
            else:
                # entry is not paired with the param; keep rank 0's value
                new_state[state_key] = state_tensor
        run_if_not_none(self.save_fn, self.sharder.append(idx, new_state))

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())
|
||||
|
||||
|
||||
class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor):
    """Redistributes optimizer-state shards to a new parallel layout, writing
    one output stream per target process."""

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        # created lazily: the sharders need param_groups from the first shard
        self.sharders: List[OptimizerCheckpointSharder] = []
        # param key -> tp rank -> dp rank -> list of target global ranks
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for key, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[key][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def setup(self, param_groups: dict) -> None:
        if not self.sharders:
            for _ in range(len(self.save_fns)):
                self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups))

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        if dist_metas:
            assert len(dist_metas) == len(states)
        new_states = [{} for _ in self.save_fns]
        param_key = self.os_to_param[idx]
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                # merge across ranks (unless the source is already global),
                # then re-split the entry to the target layout
                tensor = merge_param([s[state_key] for s in states], dist_metas) if dist_metas else state_tensor
                for tp_rank, tensor_list in enumerate(
                        unmerge_param(tensor, self.redist_meta.param_meta[param_key])):
                    for dp_rank, piece in enumerate(tensor_list):
                        for rank in self.rank_map[param_key][tp_rank][dp_rank]:
                            new_states[rank][state_key] = piece
            else:
                # entry is not paired with the param; replicate rank 0's value
                for new_state in new_states:
                    new_state[state_key] = state_tensor
        for rank, new_state in enumerate(new_states):
            run_if_not_none(self.save_fns[rank], self.sharders[rank].append(idx, new_state))

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())
|
|
@ -0,0 +1,127 @@
|
|||
import torch
|
||||
from numpy import prod
|
||||
from torch import Tensor
|
||||
from typing import List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
from .meta import ParamDistMeta, ParamRedistMeta
|
||||
|
||||
|
||||
def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    """Reassemble a parameter that ZeRO flattened and partitioned across dp ranks.

    ``tensors`` holds one flattened shard per dp rank; shards are concatenated
    in dp-rank order and reshaped to the original parameter shape. Without
    ZeRO every entry is a full replica and the first is returned as-is.
    """
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.'
    if not dist_metas[0].used_zero:
        # tensors are replicate
        return tensors[0]
    numel = dist_metas[0].zero_numel
    orig_shape = dist_metas[0].zero_orig_shape
    ordered = [shard for _, shard in sorted(zip(dist_metas, tensors), key=lambda pair: pair[0].dp_rank)]
    assert numel == sum(t.numel() for t in ordered), 'Expect numel of all params is equal to zero_numel.'
    return torch.cat(ordered).reshape(orig_shape)
|
||||
|
||||
|
||||
def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    """Gather tensor-parallel shards back into the full parameter.

    Shards are concatenated dimension by dimension in descending shard-dim
    order, following ``tp_shard_dims``/``tp_num_parts``. Without tp every
    entry is a full replica and the first is returned as-is.
    """
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.'
    for t in tensors[1:]:
        assert t.shape == tensors[0].shape, 'Expect all params have the same shape.'
    if not dist_metas[0].used_tp:
        # tensors are replicate
        return tensors[0]
    # BUGFIX: use dist_metas[0] here — the loop variable `dist_meta` above is
    # unbound when only one shard is given (NameError); all metas are equal
    total_parts = prod(dist_metas[0].tp_num_parts)
    assert dist_metas[0].tp_world_size == total_parts, \
        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_metas[0].tp_world_size}.'
    shard_info = sorted(zip(dist_metas[0].tp_shard_dims, dist_metas[0].tp_num_parts),
                        key=lambda t: t[0],
                        reverse=True)
    for dim, num_parts in shard_info:
        buffer = []
        for start in range(0, len(tensors), num_parts):
            buffer.append(torch.cat(tensors[start:start + num_parts], dim))
        tensors = buffer
    assert len(tensors) == 1
    return tensors[0]
|
||||
|
||||
|
||||
def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None:
    """Check that every dist meta agrees on the dp and tp world sizes."""
    assert len(dist_metas) > 0
    reference = dist_metas[0]
    for dist_meta in dist_metas[1:]:
        assert dist_meta.dp_world_size == reference.dp_world_size, \
            'Expect all dist meta have the same dp_world_size'
        assert dist_meta.tp_world_size == reference.tp_world_size, \
            'Expect all dist meta have the same tp_world_size'
|
||||
|
||||
|
||||
def deduplicate_params(tensors: List[Tensor],
                       dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]:
    """Drop replicated shards, keeping the first occurrence of each dist meta.

    Membership is tested with ``==``, so metas that compare equal count as
    replicas of the same shard.
    """
    kept_tensors: List[Tensor] = []
    kept_metas: List[ParamDistMeta] = []
    for tensor, dist_meta in zip(tensors, dist_metas):
        if dist_meta not in kept_metas:
            kept_metas.append(dist_meta)
            kept_tensors.append(tensor)
    return kept_tensors, kept_metas
|
||||
|
||||
|
||||
def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    """Merge ZeRO- and/or tp-sharded copies of one parameter into a full tensor."""
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    # validate parallel info
    validate_parallel_info(dist_metas)
    tensors, dist_metas = deduplicate_params(tensors, dist_metas)
    # group the zero shards by tp rank
    shards_by_tp = defaultdict(list)
    metas_by_tp = defaultdict(list)
    for tensor, dist_meta in zip(tensors, dist_metas):
        shards_by_tp[dist_meta.tp_rank].append(tensor)
        metas_by_tp[dist_meta.tp_rank].append(dist_meta)
    tp_world_size = dist_metas[0].tp_world_size
    assert len(shards_by_tp) == tp_world_size, f'Expect {tp_world_size} ranks, got {len(shards_by_tp)}'
    # first undo the ZeRO flattening within each tp rank, then gather across tp
    unflattened_tensors = [
        unflatten_zero_param(shards_by_tp[tp_rank], metas_by_tp[tp_rank]) for tp_rank in shards_by_tp.keys()
    ]
    return gather_tp_param(unflattened_tensors, [metas[0] for metas in metas_by_tp.values()])
|
||||
|
||||
|
||||
def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    """Split a full parameter into its tensor-parallel shards.

    Splitting proceeds dimension by dimension in ascending shard-dim order.
    """
    if not redist_meta.used_tp:
        assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.'
        return [tensor]
    total_parts = prod(redist_meta.tp_num_parts)
    assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.'
    shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0])
    shards = [tensor]
    for dim, num_parts in shard_info:
        next_shards = []
        for shard in shards:
            assert shard.size(dim) % num_parts == 0, \
                f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.'
            # contiguous copies so downstream serialization sees dense tensors
            next_shards.extend(chunk.contiguous() for chunk in shard.chunk(num_parts, dim))
        shards = next_shards
    assert len(shards) == redist_meta.tp_world_size
    return shards
|
||||
|
||||
|
||||
def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    """Flatten and slice ``tensor`` into per-dp-rank ZeRO shards.

    Ranks outside the ZeRO span receive empty tensors; without ZeRO every
    rank receives the full tensor.
    """
    if not redist_meta.used_zero:
        return [tensor] * redist_meta.dp_world_size

    def _empty_shard() -> Tensor:
        return torch.empty(0, dtype=tensor.dtype, device=tensor.device)

    # ranks before the ZeRO span hold nothing
    shards: List[Optional[Tensor]] = [_empty_shard() for _ in range(redist_meta.zero_start_dp_rank)]
    flat = tensor.view(-1)
    boundaries = redist_meta.zero_offsets + [tensor.numel()]
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        shards.append(flat[start:end])
    # pad trailing ranks after the ZeRO span
    if len(shards) < redist_meta.dp_world_size:
        shards.extend(_empty_shard() for _ in range(redist_meta.dp_world_size - len(shards)))
    assert len(shards) == redist_meta.dp_world_size
    return shards
|
||||
|
||||
|
||||
def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]:
    """Split a full parameter for redistribution.

    Returns one list per tp rank, each holding that rank's per-dp-rank
    ZeRO shards.
    """
    tp_shards = split_tp_param(tensor, redist_meta)
    return [flatten_zero_param(shard, redist_meta) for shard in tp_shards]
|
|
@ -0,0 +1,170 @@
|
|||
import warnings
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.nn import Module
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from .backend import get_backend
|
||||
from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
|
||||
OptimizerCheckpointRedistor)
|
||||
from .meta import ParamDistMeta, RedistMeta
|
||||
from .utils import build_checkpoints, optimizer_load_state_dict
|
||||
|
||||
|
||||
def save(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         param_to_os: Optional[Dict[str, int]] = None,
         dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
         max_shard_size_gb: float = 0.0,
         overwrite: bool = False,
         backend: str = 'disk',
         **kwargs: Any) -> None:
    """Save model (and optionally optimizer) checkpoints through ``backend``.

    Extra ``kwargs`` are persisted in the "others" checkpoint. ``dist_meta``
    is required when running with more than one process.
    """
    io_backend = get_backend(backend)
    if dist.is_initialized():
        rank, world_size = dist.get_rank(), dist.get_world_size()
    else:
        rank, world_size = 0, 1
    if world_size == 1:
        # global doesn't need dist_meta
        dist_meta = None
    else:
        assert dist_meta is not None
    max_shard_size = int(max_shard_size_gb * 1024**3)
    model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer,
                                                                                 param_to_os, dist_meta)
    writer = io_backend.get_writer(path, overwrite, rank, world_size)
    writer.save_others(kwargs)
    for shard in model_checkpoints:
        writer.save_model(shard)
    for shard in optimizer_checkpoints:
        writer.save_optimizer(shard)
    writer.save_meta(meta_checkpoint)
||||
|
||||
|
||||
def merge(path: str,
          output_path: str,
          max_shard_size_gb: float = 0.0,
          overwrite: bool = False,
          backend: str = 'disk') -> bool:
    """Merge the distributed checkpoint at ``path`` into a global one.

    Returns True iff a merged checkpoint was written; non-zero ranks and
    already-global checkpoints are a no-op returning False.
    """
    io_backend = get_backend(backend)
    # only rank 0 performs the merge
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    ckpt_reader = io_backend.get_reader(path)
    if len(ckpt_reader.meta_list) == 1:
        # already global
        warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
        return False
    dist_meta_list, param_count, param_to_os, paired_os = ckpt_reader.load_meta()
    ckpt_writer = io_backend.get_writer(output_path, overwrite=overwrite)
    ckpt_writer.save_others(ckpt_reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(ModelCheckpointMerger(max_shard_size, ckpt_writer.save_model, param_count),
                    ckpt_reader.load_models(), dist_meta_list)
    _convert_shards(
        OptimizerCheckpointMerger(max_shard_size, ckpt_writer.save_optimizer, param_count, param_to_os, paired_os),
        ckpt_reader.load_optimizers(), dist_meta_list)
    # a merged checkpoint is global, hence no dist meta
    meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
    if param_to_os is not None:
        meta_checkpoint['param_to_os'] = param_to_os
        meta_checkpoint['paired_os'] = paired_os
    ckpt_writer.save_meta(meta_checkpoint)
    return True
|
||||
|
||||
|
||||
def redist(path: str,
           output_path: str,
           redist_meta: RedistMeta,
           dist_metas: List[Dict[str, ParamDistMeta]],
           max_shard_size_gb: float = 0.0,
           overwrite: bool = False,
           backend: str = 'disk') -> bool:
    """Rewrite the checkpoint at ``path`` for the parallel layout in ``dist_metas``.

    Returns True iff a redistributed checkpoint was written; non-zero ranks
    and already-matching layouts are a no-op returning False.
    """
    io_backend = get_backend(backend)
    # only rank 0 performs the conversion
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    nprocs = len(dist_metas)
    ckpt_reader = io_backend.get_reader(path)
    dist_meta_list, param_count, param_to_os, paired_os = ckpt_reader.load_meta()
    # redistribution is needed iff the stored layout differs from the target
    do_redist: bool = len(dist_meta_list) != nprocs or any(a != b for a, b in zip(dist_metas, dist_meta_list))
    if not do_redist:
        warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.')
        return False

    writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
    writers[0].save_others(ckpt_reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(
        ModelCheckpointRedistor(max_shard_size, [w.save_model for w in writers], param_count, redist_meta),
        ckpt_reader.load_models(), dist_meta_list)
    _convert_shards(
        OptimizerCheckpointRedistor(max_shard_size, [w.save_optimizer for w in writers], param_count, param_to_os,
                                    paired_os, redist_meta), ckpt_reader.load_optimizers(), dist_meta_list)
    for writer, dist_meta in zip(writers, dist_metas):
        meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
        if param_to_os is not None:
            meta_checkpoint['param_to_os'] = param_to_os
            meta_checkpoint['paired_os'] = paired_os
        writer.save_meta(meta_checkpoint)
    return True
|
||||
|
||||
|
||||
def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
                    dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
    """Feed every shard batch from ``shard_generator`` into ``convertor``,
    then finalize it."""
    for shard_dict in shard_generator:
        convertor.append(shard_dict, dist_meta_list)
    convertor.complete()
|
||||
|
||||
|
||||
def load(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         redist_meta: Optional[RedistMeta] = None,
         dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
         max_shard_size_gb: float = 0.0,
         backend: str = 'disk') -> dict:
    """Load a checkpoint into ``model`` (and ``optimizer``), converting its
    layout on the fly when needed.

    Returns the "others" dict stored alongside the checkpoint.
    """
    is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1
    rank: int = dist.get_rank() if dist.is_initialized() else 0
    is_main_process: bool = rank == 0
    # validate args: redistribution targets are mandatory in the distributed case
    if redist_meta is None or dist_metas is None:
        assert is_global
    io_backend = get_backend(backend)
    read_path: str = path
    if is_main_process:
        # rank 0 pre-processes the checkpoint into a temp dir when its stored
        # layout does not match what we are loading into
        temp_path = io_backend.get_temp(path)
        if is_global:
            wrote = merge(path, temp_path, max_shard_size_gb, backend=backend)
        else:
            wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
        if wrote:
            read_path = temp_path
    if not is_global:
        # every rank must read from the same (possibly converted) location
        bcast_list = [read_path] if is_main_process else [None]
        dist.broadcast_object_list(bcast_list)
        read_path = bcast_list[0]
    reader = io_backend.get_reader(read_path)
    # load model shards; strict=False since each shard is partial
    for shard in reader.load_model(rank):
        model.load_state_dict(shard, strict=False)
    if optimizer is not None:
        for shard in reader.load_optimizer(rank):
            optimizer_load_state_dict(optimizer, shard)
    others_dict = reader.load_others()
    if not is_global:
        # wait for all ranks before rank 0 removes the temp dir
        dist.barrier()
    if is_main_process:
        io_backend.clean_temp()
    return others_dict
|
|
@ -0,0 +1,81 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Set, Dict
|
||||
|
||||
|
||||
@dataclass
class ParamDistMeta:
    """Describes how one parameter is distributed across dp/tp ranks."""
    # parallel info
    dp_rank: int
    dp_world_size: int
    tp_rank: int
    tp_world_size: int
    # tp info: dims the param is sharded along and how many parts per dim
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info: flattened shard size and the original (unflattened) shape
    zero_numel: Optional[int] = None
    zero_orig_shape: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        """Whether tensor parallelism applies to this param (both tp fields set)."""
        return self.tp_shard_dims is not None and self.tp_num_parts is not None

    @property
    def used_zero(self) -> bool:
        """Whether ZeRO partitioning applies to this param (both zero fields set)."""
        return self.zero_numel is not None and self.zero_orig_shape is not None

    @property
    def parallel_meta(self) -> tuple:
        """(dp_rank, dp_world_size, tp_rank, tp_world_size)."""
        return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size

    @property
    def tp_meta(self) -> tuple:
        """(tp_shard_dims, tp_num_parts)."""
        return self.tp_shard_dims, self.tp_num_parts

    @property
    def zero_meta(self) -> tuple:
        """(zero_numel, zero_orig_shape)."""
        return self.zero_numel, self.zero_orig_shape

    @staticmethod
    def from_dict(d: dict) -> 'ParamDistMeta':
        """Alternate constructor from a plain dict of field values."""
        return ParamDistMeta(**d)
|
||||
|
||||
|
||||
@dataclass
class ParamRedistMeta:
    """Describes how one parameter should be re-sharded for a target layout."""
    # parallel info
    dp_world_size: int
    tp_world_size: int
    # tp info: dims to shard along and how many parts per dim
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info: first dp rank holding a shard and shard start offsets
    zero_start_dp_rank: Optional[int] = None
    zero_offsets: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        """Whether tensor-parallel splitting applies (both tp fields set)."""
        return self.tp_shard_dims is not None and self.tp_num_parts is not None

    @property
    def used_zero(self) -> bool:
        """Whether ZeRO flattening applies (both zero fields set)."""
        return self.zero_start_dp_rank is not None and self.zero_offsets is not None
|
||||
|
||||
|
||||
@dataclass
class RankRedistMeta:
    """The (dp, tp, pp) coordinates of one global rank in the target layout."""
    dp_rank: int
    tp_rank: int
    pp_rank: int
|
||||
|
||||
|
||||
@dataclass
class PipelineRedistMeta:
    """The set of parameter names assigned to one pipeline stage."""
    params: Set[str]
|
||||
|
||||
|
||||
@dataclass
class RedistMeta:
    """Full description of the target layout for checkpoint redistribution."""
    # param key -> global rank -> that rank's (dp, tp, pp) coordinates
    rank_meta: Dict[str, Dict[int, RankRedistMeta]]
    pipeline_meta: List[PipelineRedistMeta]
    # param key -> how that param should be re-sharded
    param_meta: Dict[str, ParamRedistMeta]
|
|
@ -0,0 +1,131 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter
|
||||
from typing import Dict, Generator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
|
||||
from .meta import ParamDistMeta
|
||||
from .utils import is_duplicated_list
|
||||
|
||||
|
||||
class CheckpointReader(ABC):
    """Abstract reader over a (possibly distributed) checkpoint.

    ``meta_list`` holds one meta dict per source rank once populated.
    """

    def __init__(self, base_name: str) -> None:
        super().__init__()
        self.base_name = base_name
        # one meta dict per source rank
        self.meta_list = []

    @abstractmethod
    def read(self, name: str) -> dict:
        """Deserialize one checkpoint file by name."""
        ...

    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Return (dist metas, per-param shard counts, param_to_os, paired_os)."""
        ...

    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield one rank's model shards."""
        ...

    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield {rank: model shard} batches across all ranks."""
        ...

    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield one rank's optimizer shards."""
        ...

    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield {rank: optimizer shard} batches across all ranks."""
        ...

    @abstractmethod
    def load_others(self) -> dict:
        """Load the auxiliary (non-model, non-optimizer) checkpoint data."""
        ...
|
||||
|
||||
|
||||
class DiskCheckpointReader(CheckpointReader):
|
||||
|
||||
def __init__(self, base_name: str) -> None:
|
||||
super().__init__(base_name)
|
||||
assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
|
||||
global_meta = self.read(GLOBAL_META_FILE_NAME)
|
||||
for meta_file_name in global_meta['meta']:
|
||||
meta = self.read(meta_file_name)
|
||||
if meta.get('dist_meta', None) is None:
|
||||
# only global checkpoint can have empty dist_meta
|
||||
assert len(global_meta['meta']) == 1
|
||||
self.meta_list.append(meta)
|
||||
|
||||
def read(self, name: str) -> dict:
|
||||
return torch.load(os.path.join(self.base_name, name))
|
||||
|
||||
def load_meta(
|
||||
self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
|
||||
meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os',
|
||||
None), meta.get('paired_os', None))
|
||||
for meta in self.meta_list]
|
||||
dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
|
||||
# reduce param_count
|
||||
param_count = Counter(p for params in params_list for p in params)
|
||||
# validate param_to_os
|
||||
assert is_duplicated_list(param_to_os_list)
|
||||
assert is_duplicated_list(paired_os_list)
|
||||
return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]
|
||||
|
||||
def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
|
||||
meta = self.meta_list[rank]
|
||||
checkpoint_names = meta.get(shard_type, [])
|
||||
for name in checkpoint_names:
|
||||
yield self.read(name)
|
||||
|
||||
def load_model(self, rank: int) -> Generator[dict, None, None]:
|
||||
return self._load_shard('model', rank)
|
||||
|
||||
def load_models(self) -> Generator[Dict[int, dict], None, None]:
|
||||
indices = [0] * len(self.meta_list)
|
||||
while True:
|
||||
shards = {}
|
||||
for i, meta in enumerate(self.meta_list):
|
||||
model_checkpoint_names = meta.get('model', [])
|
||||
if indices[i] < len(model_checkpoint_names):
|
||||
shards[i] = self.read(model_checkpoint_names[indices[i]])
|
||||
indices[i] += 1
|
||||
if len(shards) > 0:
|
||||
yield shards
|
||||
else:
|
||||
break
|
||||
|
||||
def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
|
||||
param_groups = None
|
||||
for shard in self._load_shard('optimizer', rank):
|
||||
if param_groups is None:
|
||||
param_groups = shard['param_groups']
|
||||
else:
|
||||
shard['param_groups'] = param_groups
|
||||
yield shard
|
||||
|
||||
def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
    """Round-robin over ranks: yield one ``{rank: i-th optimizer shard}`` dict per
    shard index until every rank is exhausted.

    Each rank's first shard carries ``param_groups``; it is cached and copied
    into that rank's later shards.
    """
    indices = [0] * len(self.meta_list)
    param_groups = []
    while True:
        shards = {}
        for i, meta in enumerate(self.meta_list):
            optimizer_checkpoint_names = meta.get('optimizer', [])
            if indices[i] < len(optimizer_checkpoint_names):
                shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
                if indices[i] == 0:
                    # NOTE(review): param_groups is indexed by rank `i` below, but entries
                    # are appended only for ranks that actually have optimizer shards — if
                    # an earlier rank had none, param_groups[i] would misalign. Presumably
                    # all ranks save optimizer state together; confirm.
                    param_groups.append(shards[i]['param_groups'])
                else:
                    shards[i]['param_groups'] = param_groups[i]
                indices[i] += 1
        if len(shards) > 0:
            yield shards
        else:
            break
|
||||
|
||||
def load_others(self) -> dict:
    """Load the unsharded auxiliary objects saved under OTHER_CKPT_FILE_NAME (writer's ``save_others``)."""
    return self.read(OTHER_CKPT_FILE_NAME)
|
|
@ -0,0 +1,223 @@
|
|||
import warnings
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from .meta import ParamDistMeta
|
||||
|
||||
|
||||
def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
    """Apply ``fn`` to ``arg`` unless ``arg`` is None; return None in that case."""
    return fn(arg) if arg is not None else None
|
||||
|
||||
|
||||
def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
    """Map each named model parameter to its index in the optimizer's packed state.

    Indices follow ``Optimizer.state_dict()`` numbering: parameters are counted
    consecutively across param groups, in group order. Every optimizer parameter
    must also be a model parameter (asserted).
    """
    # ensure all params in the optimizer belong to the model
    model_param_ids = {id(p) for p in model.parameters()}
    for group in optimizer.param_groups:
        for param in group['params']:
            assert id(param) in model_param_ids
    # number parameters consecutively across groups, keeping the first index
    # assigned to a parameter that appears in several groups
    index_of: Dict[int, int] = {}
    offset = 0
    for group in optimizer.param_groups:
        for i, param in enumerate(group['params'], offset):
            index_of.setdefault(id(param), i)
        offset += len(group['params'])
    return {name: index_of[id(param)] for name, param in model.named_parameters()}
|
||||
|
||||
|
||||
def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
    """Total bytes of all tensor values in one per-parameter optimizer state dict.

    Non-tensor entries (e.g. an int step counter) contribute nothing.
    """
    return sum(value.numel() * value.element_size() for value in state.values() if isinstance(value, Tensor))
|
||||
|
||||
|
||||
class ModelCheckpointSharder:
    """Accumulates (name, tensor) pairs into shards of roughly ``max_shard_size`` bytes.

    A full buffer is flushed lazily on the *next* append, so a shard may exceed
    the limit by one tensor. ``max_shard_size <= 0`` disables sharding (a single
    shard for everything).
    """

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, Tensor] = {}
        self.buffer_size: int = 0

    def append(self, key: str, tensor: Tensor) -> Optional[dict]:
        """Add one tensor; return a completed shard if the buffer was already full, else None."""
        flushed = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            flushed, self.buffer, self.buffer_size = self.buffer, {}, 0
        self.buffer[key] = tensor
        self.buffer_size += tensor.numel() * tensor.element_size()
        return flushed

    def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
        """Append every entry of ``state_dict``; return all shards completed along the way."""
        completed = []
        for key, tensor in state_dict.items():
            flushed = self.append(key, tensor)
            if flushed is not None:
                completed.append(flushed)
        return completed

    def complete(self) -> Optional[dict]:
        """Return the final partial shard, or None if nothing is buffered."""
        return self.buffer or None
|
||||
|
||||
|
||||
class OptimizerCheckpointSharder:
    """Shards an optimizer state dict by per-parameter state size.

    Only the very first shard carries ``param_groups``; later shards hold just a
    ``state`` mapping. As with ModelCheckpointSharder, a full buffer is flushed
    on the *next* append, and ``max_shard_size <= 0`` disables sharding.
    """

    def __init__(self, max_shard_size: int, param_groups: dict) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
        self.buffer_size: int = 0
        # NOTE(review): never read anywhere in this class; kept for compatibility
        self.returned_first: bool = False

    def append(self, key: int, state: dict) -> Optional[dict]:
        """Add one parameter's state; return a completed shard if the buffer was already full, else None."""
        flushed = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            flushed = self.buffer
            self.buffer = {'state': {}}
            self.buffer_size = 0
        self.buffer['state'][key] = state
        self.buffer_size += compute_optimizer_state_size(state)
        return flushed

    def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
        """Append every per-parameter state in ``state_dict['state']``; return completed shards."""
        completed = []
        for key, state in state_dict['state'].items():
            flushed = self.append(key, state)
            if flushed is not None:
                completed.append(flushed)
        return completed

    def complete(self) -> Optional[dict]:
        """Return the final partial shard, or None if no state was buffered."""
        return self.buffer if self.buffer['state'] else None
|
||||
|
||||
|
||||
def shard_checkpoint(max_shard_size: int,
                     model_state_dict: Dict[str, Tensor],
                     optimizer_state_dict: Optional[dict] = None,
                     param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
    """Split model (and optionally optimizer) state dicts into size-bounded shards.

    When ``optimizer_state_dict`` is given, ``param_to_os`` is required and every
    optimizer state index must map back to a key of ``model_state_dict``
    (asserted). Returns ``(model_shards, optimizer_shards)``; the latter is empty
    when no optimizer state is provided.
    """
    if optimizer_state_dict is not None:
        assert param_to_os is not None
        # every optimizer state entry must correspond to a model parameter
        os_to_param = {index: name for name, index in param_to_os.items()}
        for os_key in optimizer_state_dict['state']:
            assert os_key in os_to_param
            assert os_to_param[os_key] in model_state_dict
    model_sharder = ModelCheckpointSharder(max_shard_size)
    model_shards = model_sharder.extend(model_state_dict)
    last_model_shard = model_sharder.complete()
    if last_model_shard is not None:
        model_shards.append(last_model_shard)
    if optimizer_state_dict is None:
        return model_shards, []
    optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
    optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
    last_optimizer_shard = optimizer_sharder.complete()
    if last_optimizer_shard is not None:
        optimizer_shards.append(last_optimizer_shard)
    return model_shards, optimizer_shards
|
||||
|
||||
|
||||
def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict:
    """For each optimizer state index, flag which state entries are tensors shaped like the paired parameter.

    Returns ``{os_index: {state_key: bool}}`` — e.g. Adam's ``exp_avg`` matches its
    parameter's shape (True) while a scalar ``step`` does not (False).
    """
    os_to_param = {index: name for name, index in param_to_os.items()}
    paired = {}
    for index, state in optimizer_state_dict['state'].items():
        param = model_state_dict[os_to_param[index]]
        paired[index] = {
            key: isinstance(value, Tensor) and value.shape == param.shape for key, value in state.items()
        }
    return paired
|
||||
|
||||
|
||||
def build_checkpoints(max_size: int,
                      model: Module,
                      optimizer: Optional[Optimizer] = None,
                      param_to_os: Optional[Dict[str, int]] = None,
                      dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
                      eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
    """Turn a model (and optionally its optimizer) into sharded checkpoint dicts plus a meta dict.

    Shards are bounded by ``max_size`` bytes (<= 0 means one shard).
    ``dist_meta is None`` marks a *global* (non-distributed) checkpoint; with a
    dist_meta and ``eliminate_replica``, data-parallel replicas are filtered so
    only dp_rank-0 (or ZeRO-partitioned) entries are kept.
    Returns ``(model_checkpoints, optimizer_checkpoints, meta)``.
    """
    save_global = dist_meta is None
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict() if optimizer else None
    meta = {'dist_meta': dist_meta}
    if optimizer:
        param_to_os = param_to_os or get_param_to_os(model, optimizer)
        paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
        meta['param_to_os'] = param_to_os
        meta['paired_os'] = paired_os
    if not save_global and eliminate_replica:
        # filter dp replicated params: keep a param only on dp_rank 0, unless ZeRO
        # already partitions it across the dp group (used_zero)
        model_state_dict = {
            k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
        }
        if optimizer:
            # NOTE(review): assumes every retained param has an optimizer state entry;
            # a param the optimizer never stepped would raise KeyError here — confirm
            optimizer_state_dict['state'] = {
                param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]]
                for k in model_state_dict.keys()
                if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
            }
    meta['params'] = list(model_state_dict.keys())
    if len(model_state_dict) == 0:
        # nothing left to save on this rank (everything filtered as a replica)
        warnings.warn('model state dict is empty, checkpoint is not saved')
        return [], [], meta
    model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
                                                                param_to_os)
    return model_checkpoints, optimizer_checkpoints, meta
|
||||
|
||||
|
||||
def is_duplicated_list(list_: List[Any]) -> bool:
    """Return True when all elements compare equal (an empty list counts as uniform)."""
    if not list_:
        return True
    first = list_[0]
    return all(item == first for item in list_[1:])
|
||||
|
||||
|
||||
def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
    """Copy entries of ``src_state`` onto matching keys of ``dest_state``.

    Tensor entries already present in ``dest_state`` are overwritten in place via
    ``copy_`` (preserving identity/device); non-tensor entries are reassigned.
    Keys absent from ``dest_state`` are left untouched.
    """
    for key, value in src_state.items():
        if key not in dest_state:
            continue
        current = dest_state[key]
        if isinstance(current, Tensor):
            current.copy_(value)
        else:
            dest_state[key] = value
|
||||
|
||||
|
||||
def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
    """Load per-parameter state into ``optimizer`` without rebuilding its groups.

    Unlike ``Optimizer.load_state_dict``, existing state tensors are updated in
    place via ``copy_optimizer_state``; ``param_groups`` must already match the
    saved ones exactly (asserted). With ``strict``, unexpected or missing state
    keys raise a RuntimeError.
    """
    assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
    # deepcopy so in-place updates below never alias the caller's dict
    state_dict = deepcopy(state_dict)
    groups = optimizer.param_groups
    saved_groups = state_dict['param_groups']
    # pair saved parameter indices with live Parameter objects by zipping the two
    # flattened group lists — relies on groups (and params within each group)
    # appearing in the same order on both sides
    idx_to_p: Dict[str, Parameter] = {
        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
                                                           )), chain.from_iterable((g['params'] for g in groups)))
    }
    missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
    unexpected_keys = []
    error_msgs = []
    for idx, state in state_dict['state'].items():
        if idx in idx_to_p:
            old_state = optimizer.state[idx_to_p[idx]]
            copy_optimizer_state(state, old_state)
        else:
            unexpected_keys.append(idx)
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
                                                                                     "\n\t".join(error_msgs)))
|
|
@ -0,0 +1,98 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME
|
||||
import torch
|
||||
import os
|
||||
|
||||
|
||||
class CheckpointWriter(ABC):
    """Abstract sink for one rank's checkpoint artifacts.

    A writer is bound to a ``(rank, world_size)`` pair; ``is_main_process``
    gates artifacts that are written only once per job (e.g. global meta,
    auxiliary 'others' — see DiskCheckpointWriter).
    """

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__()
        # destination root (for disk backends, a directory path)
        self.base_name = base_name
        # whether existing checkpoint files may be replaced
        self.overwrite = overwrite
        self.rank = rank
        self.world_size = world_size
        self.is_distributed = world_size > 1
        self.is_main_process = rank == 0

    @abstractmethod
    def write(self, name: str, state_dict: dict) -> None:
        """Persist ``state_dict`` under the artifact name ``name``."""
        pass

    @abstractmethod
    def save_model(self, model_checkpoint: dict) -> None:
        """Persist one model shard."""
        pass

    @abstractmethod
    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        """Persist one optimizer shard."""
        pass

    @abstractmethod
    def save_meta(self, meta_checkpoint: dict) -> None:
        """Persist this rank's meta dict (called after all shards; see DiskCheckpointWriter)."""
        pass

    @abstractmethod
    def save_others(self, kwargs: dict) -> None:
        """Persist unsharded auxiliary objects (main process only in DiskCheckpointWriter)."""
        pass
|
||||
|
||||
|
||||
class DiskCheckpointWriter(CheckpointWriter):
    """Writes checkpoint shards as torch files inside the ``base_name`` directory.

    File names encode rank (when distributed) and shard index, e.g.
    ``...-rank0-shard1.bin``. Meta must be saved last: it records the names of
    every shard written so far, and further shard saves are rejected afterwards.
    """

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__(base_name, overwrite, rank, world_size)
        if not os.path.exists(base_name):
            os.makedirs(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        # shard file names recorded for this rank's meta checkpoint
        self.model_checkpoint_names = []
        self.optimizer_checkpoint_names = []
        # once meta is written, save_model/save_optimizer assert-fail
        self.is_meta_saved: bool = False
        self._save_global_meta()

    def write(self, name: str, state_dict: dict) -> None:
        """torch.save ``state_dict`` to ``base_name/name``; refuse to clobber unless ``overwrite``."""
        path = os.path.join(self.base_name, name)
        if os.path.exists(path) and not self.overwrite:
            raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
        torch.save(state_dict, path)

    def _save_global_meta(self) -> None:
        # the main process writes one index file listing every rank's meta file name
        if self.is_main_process:
            global_meta = {'meta': []}
            if self.is_distributed:
                for i in range(self.world_size):
                    global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
            else:
                global_meta['meta'].append(META_CKPT_FILE_NAME)
            self.write(GLOBAL_META_FILE_NAME, global_meta)

    def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
        """Derive a file name by inserting ``-rank{r}`` (if distributed) and ``-shard{i}`` before '.bin'."""
        checkpoint_name = base_name
        if self.is_distributed:
            checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
        if shard_idx is not None:
            checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
        return checkpoint_name

    def save_model(self, model_checkpoint: dict) -> None:
        """Write one model shard and remember its file name for the meta checkpoint."""
        assert not self.is_meta_saved, 'Cannot save model after saving meta'
        name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
        self.write(name, model_checkpoint)
        self.model_checkpoint_names.append(name)

    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        """Write one optimizer shard and remember its file name for the meta checkpoint."""
        assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
        name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
        self.write(name, optimizer_checkpoint)
        self.optimizer_checkpoint_names.append(name)

    def save_meta(self, meta_checkpoint: dict) -> None:
        """Attach the recorded shard names to ``meta_checkpoint``, write it, and lock out further shard saves."""
        if len(self.model_checkpoint_names) > 0:
            meta_checkpoint['model'] = self.model_checkpoint_names
        if len(self.optimizer_checkpoint_names) > 0:
            meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
        self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
        self.is_meta_saved = True

    def save_others(self, kwargs: dict) -> None:
        """Write the unsharded auxiliary objects once, from the main process only."""
        if self.is_main_process:
            self.write(OTHER_CKPT_FILE_NAME, kwargs)
|
|
@ -0,0 +1,120 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
|
||||
from colossalai.utils.checkpoint_io.utils import build_checkpoints
|
||||
from torch.optim import Adam
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
    """Minimal fixture model: a single 20 -> 1 linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)
|
||||
|
||||
|
||||
def test_global_model():
    """A non-distributed save with unlimited shard size yields exactly one model shard equal to state_dict."""
    model = DummyModel()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model)
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 0
    # dist_meta is None for a global (non-distributed) checkpoint
    assert meta['dist_meta'] is None
    orig_state_dict = model.state_dict()
    global_state_dict = model_checkpoints[0]
    assert set(orig_state_dict.keys()) == set(global_state_dict.keys())
    for k, v in orig_state_dict.items():
        assert torch.equal(v, global_state_dict[k])
|
||||
|
||||
|
||||
def test_global_model_shard():
    """An 80-byte shard limit splits the model into exactly two disjoint shards covering all keys."""
    model = DummyModel()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model)
    assert len(model_checkpoints) == 2
    assert len(optimizer_checkpoints) == 0
    assert meta['dist_meta'] is None
    orig_state_dict = model.state_dict()
    # the two shards partition the key set: union is everything, intersection empty
    assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys())
    assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0
    for k, v in orig_state_dict.items():
        for state_dict in model_checkpoints:
            if k in state_dict:
                assert torch.equal(v, state_dict[k])
|
||||
|
||||
|
||||
def test_global_optimizer():
    """Optimizer state round-trips into a single global checkpoint shard.

    Also checks the ``paired_os`` meta: every Adam state tensor except the
    scalar 'step' is shaped like its parameter.
    """
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
    assert len(optimizer_checkpoints) == 1
    assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
    for state in meta['paired_os'].values():
        for k, is_paired in state.items():
            if k == 'step':
                assert not is_paired
            else:
                assert is_paired
    orig_state_dict = optimizer.state_dict()
    state_dict = optimizer_checkpoints[0]
    for k, orig_state in orig_state_dict['state'].items():
        state = state_dict['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                # bug fix: compare the saved value against the original
                # (was the tautology `assert v2 == v2`, which never fails)
                assert v1 == v2
    assert orig_state_dict['param_groups'] == state_dict['param_groups']
|
||||
|
||||
|
||||
def test_global_optimizer_shard():
    """An 80-byte limit splits the optimizer state into two disjoint shards; only the first carries param_groups."""
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer)
    assert len(optimizer_checkpoints) == 2
    # only the first optimizer shard carries param_groups
    assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1]
    orig_state_dict = optimizer.state_dict()
    # the shards partition the state keys
    assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set(
        optimizer_checkpoints[1]['state'].keys())
    assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0
    for k, orig_state in orig_state_dict['state'].items():
        state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][
            'state'] else optimizer_checkpoints[1]['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2

    assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups']
|
||||
|
||||
|
||||
def test_dist_model_optimizer():
    """Distributed meta is passed through, and every param/state survives when nothing is filtered.

    Both calls use eliminate_replica's default (False), so all params are kept
    regardless of dp_rank.
    """
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    # dp_rank 0 and 1 for weight/bias respectively
    dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 1
    assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0]
    assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state']
    # even with both params on dp_rank 1, nothing is dropped without eliminate_replica
    dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # allow running this test module directly, without pytest
    test_global_model()
    test_global_model_shard()
    test_global_optimizer()
    test_global_optimizer_shard()
    test_dist_model_optimizer()
|
|
@ -0,0 +1,188 @@
|
|||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Dict
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.checkpoint_io.io import load, save
|
||||
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta)
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from torch.optim import Adam, Optimizer
|
||||
|
||||
|
||||
def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    """Assert two model state dicts have identical key sets and element-wise equal tensors."""
    assert set(a.keys()) == set(b.keys())
    for key, tensor in a.items():
        assert torch.equal(tensor, b[key])
|
||||
|
||||
|
||||
def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None:
    """Assert two optimizer state dicts match.

    Every per-parameter state entry is compared (tensors element-wise, others
    by equality). Unless ``ignore_param_gruops`` is True (name kept as-is —
    'gruops' is a typo — for backward compatibility), ``param_groups`` must
    also be equal.
    """
    assert set(a['state'].keys()) == set(b['state'].keys())
    for key, state_a in a['state'].items():
        state_b = b['state'][key]
        for lhs, rhs in zip(state_a.values(), state_b.values()):
            if isinstance(lhs, Tensor):
                assert torch.equal(lhs, rhs)
            else:
                assert lhs == rhs
    if not ignore_param_gruops:
        assert a['param_groups'] == b['param_groups']
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
    """Minimal fixture model: a single 20 -> 1 linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)
|
||||
|
||||
|
||||
def prepare_model_optim(shard: bool = False, zero: bool = False):
    """Build a DummyModel and an Adam optimizer with one step of state.

    With ``shard``, fc.weight is split in half along dim 1 by rank parity
    (tensor parallel). With ``zero``, the (possibly TP-sharded) weight is
    flattened and unevenly split ([3, rest]) across the dp group (rank // 2),
    and non-owner dp ranks get an empty bias. Requires an initialized process
    group when shard or zero is set.
    """
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    # populate gradients so one optimizer step creates real state
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer
|
||||
|
||||
|
||||
def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0):
    """Fill every model parameter and every optimizer state tensor with ``scalar``."""
    with torch.no_grad():
        for param in model.parameters():
            param.fill_(scalar)
        for state in optimizer.state.values():
            for value in state.values():
                if isinstance(value, Tensor):
                    value.fill_(scalar)
|
||||
|
||||
|
||||
def get_dist_metas(nprocs: int, zero: bool = False):
    """Build per-rank ParamDistMeta dicts for a 2-way TP x (nprocs // 2)-way DP layout.

    fc.weight is TP-split along dim 1 by rank parity; with ``zero`` both params
    additionally carry ZeRO numel/original-shape info matching
    ``prepare_model_optim(zero=True)``. fc.bias is never TP-sharded.
    """
    dp_world_size = nprocs // 2
    dist_metas = []
    for rank in range(nprocs):
        if zero:
            dist_metas.append({
                'fc.weight':
                    ParamDistMeta(rank // 2,
                                  dp_world_size,
                                  rank % 2,
                                  2,
                                  tp_shard_dims=[1],
                                  tp_num_parts=[2],
                                  zero_numel=10,
                                  zero_orig_shape=[1, 10]),
                'fc.bias':
                    ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
            })
        else:
            dist_metas.append({
                'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
                'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
            })
    return dist_metas
|
||||
|
||||
|
||||
def get_redist_meta(nprocs: int):
    """Build a RedistMeta describing the target 2-way TP x (nprocs // 2)-way DP layout.

    rank_meta maps each param to per-rank (dp_rank, tp_rank, zero_rank) info;
    param_meta carries the per-param TP split spec.
    """
    dp_world_size = nprocs // 2
    rank_meta = {
        'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
        'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
    }
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_world_size, 1)
    }
    return RedistMeta(rank_meta, [], param_meta)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0])
def test_save_global_load_global(max_shard_size_gb: float):
    """A global save/load round-trip restores model and optimizer states exactly (sharded and unsharded)."""
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb)
        new_model, new_optimizer = prepare_model_optim()
        load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb)
        check_model_state_dict(model.state_dict(), new_model.state_dict())
        check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict())
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, func):
    """Per-process entry: initialize the colossalai NCCL process group, then run ``func``."""
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    func()
|
||||
|
||||
|
||||
def launch_dist(fn, world_size: int):
    """Spawn ``world_size`` processes that each run ``fn`` inside an initialized process group."""
    proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
    mp.spawn(proc_fn, nprocs=world_size)
|
||||
|
||||
|
||||
def save_dist(dir_name: str, zero: bool):
    """Save a TP-sharded (optionally ZeRO-partitioned) checkpoint from the current rank, with all-zero states."""
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    reset_model_optim(model, optimizer)
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    save(dir_name, model, optimizer, dist_meta=get_dist_metas(world_size, zero)[rank])
|
||||
|
||||
|
||||
def load_and_check_dist(dir_name: str):
    """Load a checkpoint into a perturbed TP-sharded model and verify states are restored exactly."""
    world_size = dist.get_world_size()
    model, optimizer = prepare_model_optim(shard=True)
    reset_model_optim(model, optimizer)
    # snapshot the expected (all-zero) states, then perturb before loading
    expected_model_state = deepcopy(model.state_dict())
    expected_optim_state = deepcopy(optimizer.state_dict())
    reset_model_optim(model, optimizer, 1)
    load(dir_name, model, optimizer, get_redist_meta(world_size), get_dist_metas(world_size))
    check_model_state_dict(expected_model_state, model.state_dict())
    check_optim_state_dict(expected_optim_state, optimizer.state_dict())
|
||||
|
||||
|
||||
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_global_load_dist():
    """A global (single-process) save can be loaded back by a 4-process TP x DP job."""
    model, optimizer = prepare_model_optim()
    reset_model_optim(model, optimizer)
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 4)
|
||||
|
||||
|
||||
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist_load_dist():
    """Distributed save/load round-trips, including redistribution to a different world size."""
    with TemporaryDirectory() as dir_name:
        # save tp + dp with 2 processes
        fn = partial(save_dist, dir_name, False)
        launch_dist(fn, 2)
        # load tp + dp with the same layout
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 2)
    with TemporaryDirectory() as dir_name:
        # save tp + zero with 4 processes
        fn = partial(save_dist, dir_name, True)
        launch_dist(fn, 4)
        # load tp + dp, first shrinking to 2 processes, then back at 4
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 2)
        launch_dist(fn, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # allow running this test module directly, without pytest
    test_save_global_load_global(80 / 1024**3)
    test_save_global_load_global(0)
    test_save_global_load_dist()
    test_save_dist_load_dist()
|
|
@ -0,0 +1,127 @@
|
|||
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
|
||||
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
|
||||
from colossalai.utils.checkpoint_io.io import save, merge
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from tempfile import TemporaryDirectory
|
||||
from torch.optim import Adam
|
||||
from functools import partial
|
||||
import torch
|
||||
import os
|
||||
import pytest
|
||||
import colossalai
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
    """Minimal fixture model: a single 20 -> 1 linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)
|
||||
|
||||
|
||||
def prepare_model_optim(shard: bool = False, zero: bool = False):
    """Build a DummyModel and an Adam optimizer with one step of state.

    Same layout as the save/load test fixture, but with deterministic all-ones
    gradients so merged results are reproducible. With ``shard``, fc.weight is
    TP-split along dim 1 by rank parity; with ``zero`` it is flattened and
    unevenly split ([3, rest]) across the dp group, and non-owner dp ranks get
    an empty bias. Requires an initialized process group when shard or zero.
    """
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    # deterministic grads so one optimizer step creates reproducible state
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer
|
||||
|
||||
|
||||
def test_merge_global():
    """Merging an already-global checkpoint is a no-op: the output directory stays empty."""
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 0
    # same holds for a sharded global checkpoint
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 0
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, func):
    """Per-process entry: initialize colossalai with a 2-way 1D tensor-parallel config, then run ``func``."""
    colossalai.launch(config={'parallel': {
        'tensor': {
            'mode': '1d',
            'size': 2
        }
    }},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    func()
|
||||
|
||||
|
||||
def run_save_dist(dir_name: str, zero: bool):
    """On the current rank, save a TP(2) x DP checkpoint with per-param dist meta (ZeRO variant optional)."""
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    dp_world_size = dist.get_world_size() // 2
    if zero:
        dist_metas = {
            'fc.weight':
                ParamDistMeta(rank // 2,
                              dp_world_size,
                              rank % 2,
                              2,
                              tp_shard_dims=[1],
                              tp_num_parts=[2],
                              zero_numel=10,
                              zero_orig_shape=[1, 10]),
            'fc.bias':
                ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
        }
    else:
        dist_metas = {
            'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
            'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
        }
    save(dir_name, model, optimizer, dist_meta=dist_metas)
|
||||
|
||||
|
||||
@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_merge_tp_dp(zero: bool):
    """A 4-process TP(2) x DP(2) checkpoint merges into a single global checkpoint with full-size tensors."""
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name, zero)
        world_size = 4
        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
        mp.spawn(proc_fn, nprocs=world_size)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            # global meta + meta + model shard + optimizer shard + others
            assert len(os.listdir(output_dir)) == 5
            global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME))
            assert len(global_meta['meta']) == 1
            meta = torch.load(os.path.join(output_dir, global_meta['meta'][0]))
            # the merged checkpoint is global: no dist meta left
            assert meta['dist_meta'] is None
            assert len(meta['params']) == 2
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
            model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0]))
            assert len(model_state_dict) == 2
            # fc.weight is gathered back to its full 20-column shape
            assert model_state_dict['fc.weight'].size(1) == 20
            optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0]))
            assert len(optimizer_state_dict['state']) == 2
            assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
            assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20
            assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # allow running this test module directly, without pytest
    test_merge_global()
    test_merge_tp_dp(False)
    test_merge_tp_dp(True)
|
|
@ -0,0 +1,101 @@
|
|||
import torch
|
||||
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
|
||||
from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param
|
||||
|
||||
|
||||
def test_unflatten_zero_param_even() -> None:
    """Four equal-size ZeRO shards unflatten (and merge) back to the original 4x4 tensor."""
    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).chunk(4))
    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, unflattened_tensor)
    # merge_param must handle the pure-ZeRO case identically
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)
|
||||
|
||||
|
||||
def test_unflatten_zero_param_uneven() -> None:
    """Unevenly split ZeRO shards (13 + 3 elements) still unflatten back to the original 4x4 tensor."""
    # NOTE(review): only 2 shards with dp ranks 1 and 2 out of a declared dp world
    # size of 4 — presumably unflatten only concatenates the given shards; confirm
    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)]
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, unflattened_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)
|
||||
|
||||
|
||||
def test_gather_tp_param_1d_row() -> None:
    """1-D row-sharded (dim 0) TP chunks are gathered back into the full tensor."""
    metas = [ParamDistMeta(0, 1, rank, 4, tp_shard_dims=[0], tp_num_parts=[4]) for rank in range(4)]
    reference = torch.rand(4, 4)
    shards = [chunk.contiguous() for chunk in reference.chunk(4, 0)]
    assert torch.equal(reference, gather_tp_param(shards, metas))
    assert torch.equal(reference, merge_param(shards, metas))
|
||||
|
||||
|
||||
def test_gather_tp_param_1d_col() -> None:
    """1-D column-sharded (dim 1) TP chunks are gathered back into the full tensor."""
    metas = [ParamDistMeta(0, 1, rank, 4, tp_shard_dims=[1], tp_num_parts=[4]) for rank in range(4)]
    reference = torch.rand(4, 4)
    shards = [chunk.contiguous() for chunk in reference.chunk(4, 1)]
    assert torch.equal(reference, gather_tp_param(shards, metas))
    assert torch.equal(reference, merge_param(shards, metas))
|
||||
|
||||
|
||||
def test_gather_tp_param_2d() -> None:
    """2-D TP sharding (2 parts on dim 0, 3 parts on dim 1) is gathered correctly."""
    metas = [ParamDistMeta(0, 1, rank, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for rank in range(6)]
    reference = torch.rand(4, 6)
    # Shard order: row-major over the 2x3 TP grid.
    shards = [col.contiguous() for row in reference.chunk(2, 0) for col in row.chunk(3, 1)]
    assert torch.equal(reference, gather_tp_param(shards, metas))
    assert torch.equal(reference, merge_param(shards, metas))
|
||||
|
||||
|
||||
def test_gather_tp_param_2d_reverse() -> None:
    """2-D TP sharding with the shard dims listed in reverse order gathers identically."""
    metas = [ParamDistMeta(0, 1, rank, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for rank in range(6)]
    reference = torch.rand(4, 6)
    shards = [col.contiguous() for row in reference.chunk(2, 0) for col in row.chunk(3, 1)]
    assert torch.equal(reference, gather_tp_param(shards, metas))
    assert torch.equal(reference, merge_param(shards, metas))
|
||||
|
||||
|
||||
def test_merge_param_hybrid() -> None:
    """merge_param handles combined TP (3x2 grid) + uneven ZeRO sharding."""
    metas = [
        ParamDistMeta(rank % 2,
                      2,
                      rank // 2,
                      6,
                      tp_shard_dims=[1, 0],
                      tp_num_parts=[3, 2],
                      zero_numel=4,
                      zero_orig_shape=[2, 2]) for rank in range(12)
    ]
    reference = torch.rand(4, 6)
    # Each of the 6 TP shards is flattened and split unevenly ([1, 3]) across 2 DP ranks.
    shards = [
        piece
        for row in reference.chunk(2, 0)
        for col in row.chunk(3, 1)
        for piece in col.contiguous().reshape(-1).split([1, 3])
    ]
    assert torch.equal(reference, merge_param(shards, metas))
|
||||
|
||||
|
||||
def test_merge_param_dummy() -> None:
    """A single undistributed shard is returned unchanged by merge_param."""
    metas = [ParamDistMeta(0, 1, 0, 1)]
    reference = torch.rand(4, 6)
    assert torch.equal(reference, merge_param([reference], metas))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run every merge/gather test directly (outside pytest).
    for test_fn in (test_unflatten_zero_param_even, test_unflatten_zero_param_uneven,
                    test_gather_tp_param_1d_row, test_gather_tp_param_1d_col, test_gather_tp_param_2d,
                    test_gather_tp_param_2d_reverse, test_merge_param_hybrid, test_merge_param_dummy):
        test_fn()
|
|
@ -0,0 +1,149 @@
|
|||
import os
|
||||
from functools import partial
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
|
||||
from colossalai.utils.checkpoint_io.io import redist, save
|
||||
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta,
|
||||
RedistMeta)
|
||||
from torch.optim import Adam
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
    """Trivial single-layer model (Linear 20 -> 1) used as a checkpoint fixture."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)
|
||||
|
||||
|
||||
def prepare_model_optim(shard: bool = False, zero: bool = False):
    """Build a DummyModel plus an Adam optimizer that has taken one step.

    When ``shard`` is set, ``fc.weight`` is column-sharded (dim 1) across pairs
    of ranks. When ``zero`` is also set, the local shard is flattened and split
    unevenly across the data-parallel group, and every DP rank except 0 drops
    the bias (holds an empty tensor).
    """
    model = DummyModel()
    if shard:
        tp_rank = dist.get_rank() % 2
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[tp_rank]
    if zero:
        dp_rank = dist.get_rank() // 2
        split_sizes = [3, model.fc.weight.size(1) - 3]
        model.fc.weight.data = model.fc.weight.reshape(-1).split(split_sizes, 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    # Give every parameter a gradient so optimizer.step() populates state.
    for param in model.parameters():
        param.grad = torch.ones_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer
|
||||
|
||||
|
||||
def get_dist_metas(nprocs: int, zero: bool = False):
    """Return per-rank ParamDistMeta dicts for a 2-way TP x (nprocs//2)-way DP layout."""
    dp_size = nprocs // 2
    metas = []
    for rank in range(nprocs):
        dp_rank, tp_rank = rank // 2, rank % 2
        if zero:
            weight_meta = ParamDistMeta(dp_rank,
                                        dp_size,
                                        tp_rank,
                                        2,
                                        tp_shard_dims=[1],
                                        tp_num_parts=[2],
                                        zero_numel=10,
                                        zero_orig_shape=[1, 10])
            bias_meta = ParamDistMeta(dp_rank, dp_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
        else:
            weight_meta = ParamDistMeta(dp_rank, dp_size, tp_rank, 2, tp_shard_dims=[1], tp_num_parts=[2])
            bias_meta = ParamDistMeta(dp_rank, dp_size, 0, 1)
        metas.append({'fc.weight': weight_meta, 'fc.bias': bias_meta})
    return metas
|
||||
|
||||
|
||||
def get_redist_meta(nprocs: int):
    """Build a RedistMeta describing the target 2-way TP x (nprocs//2)-way DP layout."""
    dp_size = nprocs // 2
    rank_meta = {
        'fc.weight': {r: RankRedistMeta(r // 2, r % 2, 0) for r in range(nprocs)},
        'fc.bias': {r: RankRedistMeta(r // 2, 0, 0) for r in range(nprocs)},
    }
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_size, 1),
    }
    return RedistMeta(rank_meta, [], param_meta)
|
||||
|
||||
|
||||
def check_checkpoint_shape(dir_name: str):
    """Assert every shard checkpoint under ``dir_name`` has the expected layout."""
    global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
    for meta_name in global_meta['meta']:
        meta = torch.load(os.path.join(dir_name, meta_name))
        assert meta['dist_meta'] is not None
        assert len(meta['params']) == 2
        assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
        model_sd = torch.load(os.path.join(dir_name, meta['model'][0]))
        assert len(model_sd) == 2
        # fc.weight is column-sharded 2-way: 20 // 2 == 10 columns per shard.
        assert model_sd['fc.weight'].size(1) == 10
        optim_sd = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
        assert 'param_groups' in optim_sd and 'state' in optim_sd
        assert len(optim_sd['state']) == 2
        assert optim_sd['state'][0]['exp_avg'].size(1) == 10
        assert optim_sd['state'][0]['exp_avg_sq'].size(1) == 10
|
||||
|
||||
|
||||
def test_global_to_dist():
    """Saving a global checkpoint then redistributing it yields valid shards."""
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as save_dir:
        save(save_dir, model, optimizer)
        with TemporaryDirectory() as out_dir:
            redist(save_dir, out_dir, get_redist_meta(4), get_dist_metas(4))
            check_checkpoint_shape(out_dir)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, func):
    """Worker entry point: launch colossalai with 1-D TP of size 2, then run ``func``."""
    config = {'parallel': {'tensor': {'mode': '1d', 'size': 2}}}
    colossalai.launch(config=config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    func()
|
||||
|
||||
|
||||
def run_save_dist(dir_name: str, zero: bool):
    """Save a sharded (optionally ZeRO) checkpoint from the current rank.

    Fix: the local variable was misspelled ``optmizer``; renamed to ``optimizer``.
    """
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    save(dir_name, model, optimizer, dist_meta=get_dist_metas(4, zero)[rank])
|
||||
|
||||
|
||||
@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_dist_to_dist(zero: bool):
    """Redistribute a distributed checkpoint written by 4 spawned ranks."""
    with TemporaryDirectory() as dir_name:
        world_size = 4
        save_fn = partial(run_save_dist, dir_name, zero)
        worker = partial(run_dist, world_size=world_size, port=free_port(), func=save_fn)
        mp.spawn(worker, nprocs=world_size)
        with TemporaryDirectory() as output_dir:
            redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
            if not zero:
                # Source layout already matches the target: redist writes nothing.
                assert len(os.listdir(output_dir)) == 0
            else:
                check_checkpoint_shape(output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run the redistribution tests directly (outside pytest).
    test_global_to_dist()
    for use_zero in (False, True):
        test_dist_to_dist(use_zero)
|
|
@ -0,0 +1,147 @@
|
|||
import os
|
||||
from functools import partial
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Dict
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME,
|
||||
OTHER_CKPT_FILE_NAME)
|
||||
from colossalai.utils.checkpoint_io.io import save
|
||||
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
|
||||
from torch import Tensor
|
||||
from torch.optim import Adam
|
||||
|
||||
|
||||
def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    """Assert two model state dicts hold identical keys and element-wise equal tensors."""
    assert set(a.keys()) == set(b.keys())
    for key, tensor in a.items():
        assert torch.equal(tensor, b[key])
|
||||
|
||||
|
||||
def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
    """Assert two optimizer state dicts are equal.

    Tensor entries are compared with ``torch.equal``; everything else with ``==``.
    The ``param_groups`` comparison is skipped when ``ignore_param_groups`` is set
    (useful for sharded checkpoints where only one shard carries param_groups).

    Fix: the keyword parameter was misspelled ``ignore_param_gruops``; all visible
    call sites pass it positionally, so the rename is safe here.
    """
    assert set(a['state'].keys()) == set(b['state'].keys())
    for key, state in a['state'].items():
        other = b['state'][key]
        for v1, v2 in zip(state.values(), other.values()):
            if isinstance(v1, Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    if not ignore_param_groups:
        assert a['param_groups'] == b['param_groups']
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
    """Trivial single-layer model (Linear 20 -> 1) used as a checkpoint fixture."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)
|
||||
|
||||
|
||||
def prepare_model_optim():
    """Return a DummyModel plus an Adam optimizer that has taken one step."""
    model = DummyModel()
    # Give every parameter a gradient so optimizer.step() populates state.
    for param in model.parameters():
        param.grad = torch.ones_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer
|
||||
|
||||
|
||||
def test_overwrite():
    """save() must refuse to clobber an existing checkpoint shard by default."""
    model = DummyModel()
    with TemporaryDirectory() as dir_name:
        shard_name = MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')
        # Pre-create the shard file so the subsequent save collides with it.
        with open(os.path.join(dir_name, shard_name), 'a'):
            pass
        with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'):
            save(dir_name, model)
|
||||
|
||||
|
||||
def test_save_global():
    """A non-distributed save writes one meta/model/optimizer/other file set."""
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        # global meta + meta + model + optimizer + other = 5 files
        assert len(os.listdir(dir_name)) == 5
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME
        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
        assert len(meta['model']) == 1
        assert len(meta['optimizer']) == 1
        model_sd = torch.load(os.path.join(dir_name, meta['model'][0]))
        check_model_state_dict(model.state_dict(), model_sd)
        optim_sd = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
        check_optim_state_dict(optimizer.state_dict(), optim_sd)
        other_sd = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME))
        assert len(other_sd) == 0
|
||||
|
||||
|
||||
def test_save_global_shard():
    """With a tiny max shard size, model and optimizer each split into two shards."""
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        # 80 bytes expressed in GB forces sharding of this tiny model.
        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
        assert len(os.listdir(dir_name)) == 7
        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
        assert len(meta['model']) == 2 and len(meta['optimizer']) == 2
        model_shards = [torch.load(os.path.join(dir_name, name)) for name in meta['model']]
        # Shards must partition the parameters without overlap.
        assert not set(model_shards[0].keys()) & set(model_shards[1].keys())
        check_model_state_dict(model.state_dict(), {**model_shards[0], **model_shards[1]})
        optim_shards = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']]
        assert not set(optim_shards[0]['state'].keys()) & set(optim_shards[1]['state'].keys())
        # Only the first optimizer shard carries param_groups.
        assert 'param_groups' in optim_shards[0] and 'param_groups' not in optim_shards[1]
        merged = {
            'state': {
                **optim_shards[0]['state'],
                **optim_shards[1]['state']
            },
            'param_groups': optim_shards[0]['param_groups']
        }
        check_optim_state_dict(optimizer.state_dict(), merged)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, func):
    """Worker entry point: initialise a plain colossalai context, then run ``func``."""
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    func()
|
||||
|
||||
|
||||
def run_save_dist(dir_name):
    """Save a checkpoint from the current rank with trivial (unsharded) dist meta.

    Fix: the local variable was misspelled ``optmizer``; renamed to ``optimizer``.
    """
    model, optimizer = prepare_model_optim()
    rank, world_size = dist.get_rank(), dist.get_world_size()
    dist_metas = {
        'fc.weight': ParamDistMeta(rank, world_size, 0, 1),
        'fc.bias': ParamDistMeta(rank, world_size, 0, 1)
    }
    save(dir_name, model, optimizer, dist_meta=dist_metas)
|
||||
|
||||
|
||||
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist():
    """A 2-rank distributed save writes per-rank meta/model/optimizer files."""
    with TemporaryDirectory() as dir_name:
        world_size = 2
        worker = partial(run_dist, world_size=world_size, port=free_port(), func=partial(run_save_dist, dir_name))
        mp.spawn(worker, nprocs=world_size)
        # global meta + 2 ranks x (meta + model + optimizer) + other = 8 files
        assert len(os.listdir(dir_name)) == 8
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 2
        for meta_name in global_meta['meta']:
            meta = torch.load(os.path.join(dir_name, meta_name))
            assert meta.get('dist_meta', None) is not None
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
            model_sd = torch.load(os.path.join(dir_name, meta['model'][0]))
            assert len(model_sd) == 2
            optim_sd = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
            assert len(optim_sd['state']) == 2
            assert 'param_groups' in optim_sd
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run the save tests directly (outside pytest).
    for test_fn in (test_overwrite, test_save_global, test_save_global_shard, test_save_dist):
        test_fn()
|
|
@ -0,0 +1,137 @@
|
|||
import torch
|
||||
from colossalai.utils.checkpoint_io.meta import ParamRedistMeta
|
||||
from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param
|
||||
|
||||
|
||||
def test_flatten_zero_param_even() -> None:
    """Even ZeRO flattening yields 4 equal chunks, via both the helper and unmerge_param."""
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12])
    reference = torch.rand(4, 4)
    expected = list(reference.reshape(-1).chunk(4))
    flat = flatten_zero_param(reference, redist_meta)
    assert len(flat) == len(expected)
    for want, got in zip(expected, flat):
        assert torch.equal(want, got)
    unmerged = unmerge_param(reference, redist_meta)
    # No TP: a single TP entry holding all DP chunks.
    assert len(unmerged) == 1
    chunks = unmerged[0]
    assert len(chunks) == len(expected)
    for want, got in zip(expected, chunks):
        assert torch.equal(want, got)
|
||||
|
||||
|
||||
def test_flatten_zero_param_uneven() -> None:
    """Uneven ZeRO flattening pads ranks that hold no data with empty chunks."""
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13])
    reference = torch.rand(4, 4)
    expected = list(reference.reshape(-1).split([13, 3]))

    flat = flatten_zero_param(reference, redist_meta)
    # Ranks 0 and 3 hold nothing: their chunks must be empty.
    assert flat[0].size(0) == 0 and flat[-1].size(0) == 0
    middle = flat[1:-1]
    assert len(middle) == len(expected)
    for want, got in zip(expected, middle):
        assert torch.equal(want, got)

    unmerged = unmerge_param(reference, redist_meta)
    assert len(unmerged) == 1
    chunks = unmerged[0]
    assert chunks[0].size(0) == 0 and chunks[-1].size(0) == 0
    chunks = chunks[1:-1]
    assert len(chunks) == len(expected)
    for want, got in zip(expected, chunks):
        assert torch.equal(want, got)
|
||||
|
||||
|
||||
def test_split_tp_param_1d_row() -> None:
    """1-D row split (dim 0, 4 parts) matches torch.chunk on dim 0."""
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4])
    reference = torch.rand(4, 4)
    expected = [chunk.contiguous() for chunk in reference.chunk(4, 0)]
    actual = split_tp_param(reference, redist_meta)
    assert len(actual) == len(expected)
    for want, got in zip(expected, actual):
        assert torch.equal(want, got)
    unmerged = unmerge_param(reference, redist_meta)
    assert len(unmerged) == len(expected)
    for want, per_dp in zip(expected, unmerged):
        # No ZeRO: each TP shard maps to exactly one DP entry.
        assert len(per_dp) == 1
        assert torch.equal(want, per_dp[0])
|
||||
|
||||
|
||||
def test_split_tp_param_1d_col() -> None:
    """1-D column split (dim 1, 4 parts) matches torch.chunk on dim 1."""
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4])
    reference = torch.rand(4, 4)
    expected = [chunk.contiguous() for chunk in reference.chunk(4, 1)]
    actual = split_tp_param(reference, redist_meta)
    assert len(actual) == len(expected)
    for want, got in zip(expected, actual):
        assert torch.equal(want, got)
    unmerged = unmerge_param(reference, redist_meta)
    assert len(unmerged) == len(expected)
    for want, per_dp in zip(expected, unmerged):
        assert len(per_dp) == 1
        assert torch.equal(want, per_dp[0])
|
||||
|
||||
|
||||
def test_split_tp_param_2d() -> None:
    """2-D split (2 parts on dim 0, 3 on dim 1) matches nested torch.chunk calls."""
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3])
    reference = torch.rand(4, 6)
    expected = [col.contiguous() for row in reference.chunk(2, 0) for col in row.chunk(3, 1)]
    actual = split_tp_param(reference, redist_meta)
    assert len(actual) == len(expected)
    for want, got in zip(expected, actual):
        assert torch.equal(want, got)
    unmerged = unmerge_param(reference, redist_meta)
    assert len(unmerged) == len(expected)
    for want, per_dp in zip(expected, unmerged):
        assert len(per_dp) == 1
        assert torch.equal(want, per_dp[0])
|
||||
|
||||
|
||||
def test_split_tp_param_2d_reverse() -> None:
    """2-D split with shard dims listed in reverse order produces the same shards."""
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2])
    reference = torch.rand(4, 6)
    expected = [col.contiguous() for row in reference.chunk(2, 0) for col in row.chunk(3, 1)]
    actual = split_tp_param(reference, redist_meta)
    assert len(actual) == len(expected)
    for want, got in zip(expected, actual):
        assert torch.equal(want, got)
    unmerged = unmerge_param(reference, redist_meta)
    assert len(unmerged) == len(expected)
    for want, per_dp in zip(expected, unmerged):
        assert len(per_dp) == 1
        assert torch.equal(want, per_dp[0])
|
||||
|
||||
|
||||
def test_unmerge_param_hybrid() -> None:
    """unmerge_param applies TP splitting (3x2 grid) then uneven ZeRO flattening."""
    redist_meta = ParamRedistMeta(2,
                                  6,
                                  tp_shard_dims=[1, 0],
                                  tp_num_parts=[3, 2],
                                  zero_start_dp_rank=0,
                                  zero_offsets=[0, 1])
    reference = torch.rand(4, 6)
    # Each of the 6 TP shards is flattened and split [1, 3] across 2 DP ranks.
    expected = [
        piece for row in reference.chunk(2, 0) for col in row.chunk(3, 1)
        for piece in col.contiguous().reshape(-1).split([1, 3])
    ]
    unmerged = unmerge_param(reference, redist_meta)
    assert len(unmerged) == 6 and len(unmerged[0]) == 2
    for tp_rank in range(6):
        for dp_rank in range(2):
            assert torch.equal(expected[tp_rank * 2 + dp_rank], unmerged[tp_rank][dp_rank])
|
||||
|
||||
|
||||
def test_unmerge_param_dummy() -> None:
    """With no TP and no ZeRO, unmerge_param returns the tensor untouched."""
    redist_meta = ParamRedistMeta(1, 1)
    reference = torch.rand(4, 6)
    unmerged = unmerge_param(reference, redist_meta)
    assert len(unmerged) == 1 and len(unmerged[0]) == 1
    assert torch.equal(reference, unmerged[0][0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run every flatten/split/unmerge test directly (outside pytest).
    for test_fn in (test_flatten_zero_param_even, test_flatten_zero_param_uneven, test_split_tp_param_1d_row,
                    test_split_tp_param_1d_col, test_split_tp_param_2d, test_split_tp_param_2d_reverse,
                    test_unmerge_param_hybrid, test_unmerge_param_dummy):
        test_fn()
|
Loading…
Reference in New Issue