mirror of https://github.com/hpcaitech/ColossalAI
[test] remove useless tests (#4359)
* [test] remove legacy zero test
* [test] remove lazy distribute test
* [test] remove outdated checkpoint io

pull/4386/head
parent 16c0acc01b
commit 16bf4c0221

@@ -1,2 +0,0 @@
from .io import load, merge, redist, save
from .meta import ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta

@@ -1,74 +0,0 @@
import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Type

from .reader import CheckpointReader, DiskCheckpointReader
from .writer import CheckpointWriter, DiskCheckpointWriter

_backends: Dict[str, Type['CheckpointIOBackend']] = {}


def register(name: str):
    assert name not in _backends, f'"{name}" is already registered'

    def wrapper(cls):
        _backends[name] = cls
        return cls

    return wrapper


def get_backend(name: str) -> 'CheckpointIOBackend':
    assert name in _backends, f'Unsupported backend "{name}"'
    return _backends[name]()


class CheckpointIOBackend(ABC):

    def __init__(self) -> None:
        super().__init__()
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        pass

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        pass

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        pass

    @abstractmethod
    def clean_temp(self) -> None:
        pass


@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        for temp_dir_name in self.temps:
            shutil.rmtree(temp_dir_name)

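For context, the registry above resolves a backend by name and instantiates it on every lookup, so temp-dir bookkeeping (`self.temps`) is per-instance. A minimal usage sketch — the module path `colossalai.utils.checkpoint_io.backend` and the `./ckpt` directory are assumptions (this code was removed in #4359):

# Sketch only: resolve the registered 'disk' backend and round-trip the "others" dict.
from colossalai.utils.checkpoint_io.backend import get_backend

io_backend = get_backend('disk')    # instantiates CheckpointDiskIO
writer = io_backend.get_writer('./ckpt', overwrite=True)
writer.save_others({'step': 0})
writer.save_meta({'dist_meta': None, 'params': []})
reader = io_backend.get_reader('./ckpt')
assert reader.load_others() == {'step': 0}
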
@@ -1,9 +0,0 @@
import re

GLOBAL_META_FILE_NAME = 'global_meta.bin'
MODEL_CKPT_FILE_NAME = 'model.bin'
OPTIM_CKPT_FILE_NAME = 'optim.bin'
META_CKPT_FILE_NAME = 'meta.bin'
OTHER_CKPT_FILE_NAME = 'other.bin'

CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')

@@ -1,227 +0,0 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional

from torch import Tensor

from .distributed import merge_param, unmerge_param
from .meta import ParamDistMeta, RedistMeta
from .utils import ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none


class CheckpointConvertor(ABC):

    @abstractmethod
    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        pass

    @abstractmethod
    def complete(self) -> None:
        pass


class ModelCheckpointConvertor(CheckpointConvertor):

    def __init__(self, param_count: Dict[str, int]) -> None:
        super().__init__()
        self.param_count = param_count
        self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)

    @abstractmethod
    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            for k, tensor in state_dict.items():
                self.buffer[k][rank] = tensor
        converted_keys = set()
        for k, rank_dict in self.buffer.items():
            if len(rank_dict) == self.param_count[k]:
                tensors = []
                dist_metas = []
                for rank, tensor in rank_dict.items():
                    tensors.append(tensor)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][k])
                self.convert_tensors(k, tensors, dist_metas)
                converted_keys.add(k)
        for k in converted_keys:
            del self.buffer[k]

    def complete(self) -> None:
        assert len(self.buffer) == 0


class ModelCheckpointMerger(ModelCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None:
        super().__init__(param_count)
        self.sharder = ModelCheckpointSharder(max_shard_size)
        self.save_fn = save_fn

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(tensors)
        tensor = merge_param(tensors, dist_metas)
        shard = self.sharder.append(key, tensor)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())


class ModelCheckpointRedistor(ModelCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count)
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        nprocs = len(save_fns)
        self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)]
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        if len(dist_metas) == 0:
            # already global
            tensor = tensors[0]
        else:
            assert len(dist_metas) == len(tensors)
            tensor = merge_param(tensors, dist_metas)
        for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])):
            for dp_rank, t in enumerate(tensor_list):
                for rank in self.rank_map[key][tp_rank][dp_rank]:
                    shard = self.sharders[rank].append(key, t)
                    run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())


class OptimizerCheckpointConvertor(CheckpointConvertor):

    def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]],
                 paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__()
        self.param_count = param_count
        self.param_to_os = param_to_os
        self.paired_os = paired_os
        self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict)
        self.os_to_param = {v: k for k, v in param_to_os.items()}

    @abstractmethod
    def setup(self, param_groups: dict) -> None:
        pass

    @abstractmethod
    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            self.setup(state_dict['param_groups'])
            for idx, state in state_dict['state'].items():
                self.buffer[idx][rank] = state
        converted_indices = set()
        for idx, rank_dict in self.buffer.items():
            if len(rank_dict) == self.param_count[self.os_to_param[idx]]:
                states = []
                dist_metas = []
                for rank, state in rank_dict.items():
                    states.append(state)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]])
                self.convert_states(idx, states, dist_metas)
                converted_indices.add(idx)
        for idx in converted_indices:
            del self.buffer[idx]

    def complete(self) -> None:
        assert len(self.buffer) == 0


class OptimizerCheckpointMerger(OptimizerCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fn = save_fn
        self.sharder = None

    def setup(self, param_groups: dict) -> None:
        if self.sharder is None:
            self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups)

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(states)
        new_state = {}
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas)
            else:
                new_state[state_key] = state_tensor
        shard = self.sharder.append(idx, new_state)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())


class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        self.sharders: List[OptimizerCheckpointSharder] = []
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def setup(self, param_groups: dict) -> None:
        if len(self.sharders) == 0:
            nprocs = len(self.save_fns)
            for _ in range(nprocs):
                self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups))

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        need_merge: bool = True
        if len(dist_metas) == 0:
            need_merge = False
        else:
            assert len(dist_metas) == len(states)
        new_states = [{} for _ in range(len(self.save_fns))]
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                if need_merge:
                    tensor = merge_param([state[state_key] for state in states], dist_metas)
                else:
                    tensor = state_tensor
                for tp_rank, tensor_list in enumerate(
                        unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])):
                    for dp_rank, t in enumerate(tensor_list):
                        for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]:
                            new_states[rank][state_key] = t
            else:
                for new_state in new_states:
                    new_state[state_key] = state_tensor
        for rank, new_state in enumerate(new_states):
            shard = self.sharders[rank].append(idx, new_state)
            run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())

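These convertors are streaming consumers: the caller feeds `append` one shard dict per load step and flushes with `complete`. A hand-driven sketch of `ModelCheckpointMerger` (module paths are assumptions; two TP shards of a single parameter 'w' are merged and handed to `save_fn`):

# Sketch only: merge two TP shards of parameter 'w' into one global tensor.
import torch
from colossalai.utils.checkpoint_io.convertor import ModelCheckpointMerger
from colossalai.utils.checkpoint_io.meta import ParamDistMeta

out = []
merger = ModelCheckpointMerger(max_shard_size=0, save_fn=out.append, param_count={'w': 2})
metas = [
    {'w': ParamDistMeta(0, 1, tp_rank, 2, tp_shard_dims=[0], tp_num_parts=[2])} for tp_rank in range(2)
]
# one call per "load step": {rank: partial state dict}, plus the per-rank dist metas
merger.append({0: {'w': torch.zeros(2)}, 1: {'w': torch.ones(2)}}, metas)
merger.complete()
assert torch.equal(out[0]['w'], torch.tensor([0., 0., 1., 1.]))
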
@@ -1,127 +0,0 @@
from collections import defaultdict
from typing import List, Optional, Tuple

import torch
from numpy import prod
from torch import Tensor

from .meta import ParamDistMeta, ParamRedistMeta


def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params to have the same zero meta.'
    if not dist_metas[0].used_zero:
        # tensors are replicated
        return tensors[0]
    numel = dist_metas[0].zero_numel
    orig_shape = dist_metas[0].zero_orig_shape
    tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)]
    assert numel == sum(t.numel() for t in tensors), 'Expect the total numel of all params to equal zero_numel.'
    return torch.cat(tensors).reshape(orig_shape)


def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params to have the same tp meta.'
    for t in tensors[1:]:
        assert t.shape == tensors[0].shape, 'Expect all params to have the same shape.'
    if not dist_metas[0].used_tp:
        # tensors are replicated
        return tensors[0]
    # use dist_metas[0] here: the loop variable above is undefined when only one shard is given
    total_parts = prod(dist_metas[0].tp_num_parts)
    assert dist_metas[0].tp_world_size == total_parts, \
        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_metas[0].tp_world_size}.'
    shard_info = sorted(zip(dist_metas[0].tp_shard_dims, dist_metas[0].tp_num_parts), key=lambda t: t[0], reverse=True)
    for dim, num_parts in shard_info:
        buffer = []
        for start in range(0, len(tensors), num_parts):
            buffer.append(torch.cat(tensors[start:start + num_parts], dim))
        tensors = buffer
    assert len(tensors) == 1
    return tensors[0]


def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None:
    assert len(dist_metas) > 0
    # check world size
    for dist_meta in dist_metas[1:]:
        assert dist_meta.dp_world_size == dist_metas[0].dp_world_size, \
            'Expect all dist metas to have the same dp_world_size'
        assert dist_meta.tp_world_size == dist_metas[0].tp_world_size, \
            'Expect all dist metas to have the same tp_world_size'


def deduplicate_params(tensors: List[Tensor],
                       dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]:
    unique_dist_meta = []
    unique_idx = []
    for i, dist_meta in enumerate(dist_metas):
        if dist_meta not in unique_dist_meta:
            unique_dist_meta.append(dist_meta)
            unique_idx.append(i)
    return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx]


def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    # validate parallel info
    validate_parallel_info(dist_metas)
    tensors, dist_metas = deduplicate_params(tensors, dist_metas)
    unflattened_tensors = []
    # group zero params by tp rank
    tensor_dict = defaultdict(list)
    dist_meta_dict = defaultdict(list)
    for t, dist_meta in zip(tensors, dist_metas):
        tensor_dict[dist_meta.tp_rank].append(t)
        dist_meta_dict[dist_meta.tp_rank].append(dist_meta)
    assert len(tensor_dict) == dist_metas[0].tp_world_size, \
        f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}'
    for tp_rank in tensor_dict.keys():
        unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank]))
    return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()])


def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    if not redist_meta.used_tp:
        assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1 when no tp meta is provided.'
        return [tensor]
    total_parts = prod(redist_meta.tp_num_parts)
    assert redist_meta.tp_world_size == total_parts, \
        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.'
    shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0])
    tensors = [tensor]
    for dim, num_parts in shard_info:
        buffer = []
        for t in tensors:
            assert t.size(dim) % num_parts == 0, \
                f'Expect dim {dim} of tensor ({tensor.shape}) to be divisible by {num_parts}.'
            chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)]
            buffer.extend(chunks)
        tensors = buffer
    assert len(tensors) == redist_meta.tp_world_size
    return tensors


def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    if not redist_meta.used_zero:
        return [tensor] * redist_meta.dp_world_size
    tensors: List[Optional[Tensor]] = [
        torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank)
    ]
    offsets = redist_meta.zero_offsets + [tensor.numel()]
    for i, offset in enumerate(offsets[:-1]):
        end = offsets[i + 1]
        tensors.append(tensor.view(-1)[offset:end])
    if len(tensors) < redist_meta.dp_world_size:
        tensors.extend([
            torch.empty(0, dtype=tensor.dtype, device=tensor.device)
            for _ in range(redist_meta.dp_world_size - len(tensors))
        ])
    assert len(tensors) == redist_meta.dp_world_size
    return tensors


def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]:
    tensors = split_tp_param(tensor, redist_meta)
    tensors = [flatten_zero_param(t, redist_meta) for t in tensors]
    return tensors

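A quick round-trip check of the two directions (a sketch; module paths are assumptions, and the metadata mirrors a plain two-way TP split along dim 1 with no ZeRO flattening):

import torch
from colossalai.utils.checkpoint_io.distributed import merge_param, unmerge_param
from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta

full = torch.arange(8.).reshape(2, 4)
# two TP shards along dim 1, dp_world_size=1, no ZeRO
shards = [full[:, :2].contiguous(), full[:, 2:].contiguous()]
dist_metas = [
    ParamDistMeta(0, 1, tp_rank, 2, tp_shard_dims=[1], tp_num_parts=[2]) for tp_rank in range(2)
]
assert torch.equal(merge_param(shards, dist_metas), full)

# unmerge_param returns a list over tp ranks of lists over dp ranks
redist_meta = ParamRedistMeta(1, 2, tp_shard_dims=[1], tp_num_parts=[2])
resharded = unmerge_param(full, redist_meta)
assert torch.equal(resharded[0][0], shards[0])
assert torch.equal(resharded[1][0], shards[1])
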
@@ -1,170 +0,0 @@
import warnings
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import torch.distributed as dist
from torch.nn import Module
from torch.optim import Optimizer

from .backend import get_backend
from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
                        OptimizerCheckpointRedistor)
from .meta import ParamDistMeta, RedistMeta
from .utils import build_checkpoints, optimizer_load_state_dict


def save(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         param_to_os: Optional[Dict[str, int]] = None,
         dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
         max_shard_size_gb: float = 0.0,
         overwrite: bool = False,
         backend: str = 'disk',
         **kwargs: Any) -> None:
    io_backend = get_backend(backend)
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    if world_size == 1:
        # global checkpoints don't need dist_meta
        dist_meta = None
    else:
        assert dist_meta is not None
    max_shard_size = int(max_shard_size_gb * 1024**3)
    model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer,
                                                                                  param_to_os, dist_meta)
    writer = io_backend.get_writer(path, overwrite, rank, world_size)
    writer.save_others(kwargs)
    for model_checkpoint in model_checkpoints:
        writer.save_model(model_checkpoint)
    for optimizer_checkpoint in optimizer_checkpoints:
        writer.save_optimizer(optimizer_checkpoint)
    writer.save_meta(meta_checkpoint)


def merge(path: str,
          output_path: str,
          max_shard_size_gb: float = 0.0,
          overwrite: bool = False,
          backend: str = 'disk') -> bool:
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    reader = io_backend.get_reader(path)
    if len(reader.meta_list) == 1:
        # already global
        warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
        return False
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    writer = io_backend.get_writer(output_path, overwrite=overwrite)
    writer.save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(),
                    dist_meta_list)
    _convert_shards(
        OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os),
        reader.load_optimizers(), dist_meta_list)
    meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
    if param_to_os is not None:
        meta_checkpoint['param_to_os'] = param_to_os
        meta_checkpoint['paired_os'] = paired_os
    writer.save_meta(meta_checkpoint)
    return True


def redist(path: str,
           output_path: str,
           redist_meta: RedistMeta,
           dist_metas: List[Dict[str, ParamDistMeta]],
           max_shard_size_gb: float = 0.0,
           overwrite: bool = False,
           backend: str = 'disk') -> bool:
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    nprocs = len(dist_metas)
    reader = io_backend.get_reader(path)
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    do_redist: bool = False
    if len(dist_meta_list) == nprocs:
        for a, b in zip(dist_metas, dist_meta_list):
            if a != b:
                do_redist = True
                break
    else:
        do_redist = True
    if not do_redist:
        warnings.warn(f'Checkpoint at "{path}" does not need to be redistributed, nothing to do.')
        return False

    writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
    writers[0].save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(
        ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta),
        reader.load_models(), dist_meta_list)
    _convert_shards(
        OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count,
                                    param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list)
    for writer, dist_meta in zip(writers, dist_metas):
        meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
        if param_to_os is not None:
            meta_checkpoint['param_to_os'] = param_to_os
            meta_checkpoint['paired_os'] = paired_os
        writer.save_meta(meta_checkpoint)
    return True


def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
                    dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
    for shard_dict in shard_generator:
        convertor.append(shard_dict, dist_meta_list)
    convertor.complete()


def load(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         redist_meta: Optional[RedistMeta] = None,
         dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
         max_shard_size_gb: float = 0.0,
         backend: str = 'disk') -> dict:
    is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1
    rank: int = dist.get_rank() if dist.is_initialized() else 0
    is_main_process: bool = rank == 0
    # validate args
    if redist_meta is None or dist_metas is None:
        assert is_global
    io_backend = get_backend(backend)
    read_path: str = path
    if is_main_process:
        # pre-process checkpoints
        temp_path = io_backend.get_temp(path)
        if is_global:
            wrote = merge(path, temp_path, max_shard_size_gb, backend=backend)
        else:
            wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
        if wrote:
            read_path = temp_path
    if not is_global:
        bcast_list = [read_path] if is_main_process else [None]
        dist.broadcast_object_list(bcast_list)
        read_path = bcast_list[0]
    reader = io_backend.get_reader(read_path)
    # load model
    for shard in reader.load_model(rank):
        model.load_state_dict(shard, strict=False)
    if optimizer is not None:
        for shard in reader.load_optimizer(rank):
            optimizer_load_state_dict(optimizer, shard)
    others_dict = reader.load_others()
    if not is_global:
        dist.barrier()
    # clean up temp
    if is_main_process:
        io_backend.clean_temp()
    return others_dict

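A single-process usage sketch of this API (no process group is initialized, so the checkpoint is treated as global; extra keyword arguments to `save` travel through the "others" file; module path and `./ckpt` directory are assumptions):

import torch
import torch.nn as nn
from torch.optim import Adam
from colossalai.utils.checkpoint_io.io import load, save

model = nn.Linear(20, 1)
optimizer = Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 20)).sum().backward()
optimizer.step()
save('./ckpt', model, optimizer, epoch=3)      # 'epoch' is stored in the "others" file
others = load('./ckpt', model, optimizer)      # returns the "others" dict
assert others['epoch'] == 3
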
@@ -1,81 +0,0 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Set


@dataclass
class ParamDistMeta:
    # parallel info
    dp_rank: int
    dp_world_size: int
    tp_rank: int
    tp_world_size: int
    # tp info
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info
    zero_numel: Optional[int] = None
    zero_orig_shape: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        return self.tp_shard_dims is not None and self.tp_num_parts is not None

    @property
    def used_zero(self) -> bool:
        return self.zero_numel is not None and self.zero_orig_shape is not None

    @property
    def parallel_meta(self) -> tuple:
        return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size

    @property
    def tp_meta(self) -> tuple:
        return self.tp_shard_dims, self.tp_num_parts

    @property
    def zero_meta(self) -> tuple:
        return self.zero_numel, self.zero_orig_shape

    @staticmethod
    def from_dict(d: dict) -> 'ParamDistMeta':
        return ParamDistMeta(**d)


@dataclass
class ParamRedistMeta:
    # parallel info
    dp_world_size: int
    tp_world_size: int
    # tp info
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info
    zero_start_dp_rank: Optional[int] = None
    zero_offsets: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        return self.tp_shard_dims is not None and self.tp_num_parts is not None

    @property
    def used_zero(self) -> bool:
        return self.zero_start_dp_rank is not None and self.zero_offsets is not None


@dataclass
class RankRedistMeta:
    dp_rank: int
    tp_rank: int
    pp_rank: int


@dataclass
class PipelineRedistMeta:
    params: Set[str]


@dataclass
class RedistMeta:
    rank_meta: Dict[str, Dict[int, RankRedistMeta]]
    pipeline_meta: List[PipelineRedistMeta]
    param_meta: Dict[str, ParamRedistMeta]

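For illustration, the metadata for a weight split in half along dim 0 across two TP ranks and then ZeRO-flattened across two DP ranks might look as follows (the field values are made up for the sketch; the import path is an assumption):

from colossalai.utils.checkpoint_io.meta import ParamDistMeta

meta = ParamDistMeta(
    dp_rank=0, dp_world_size=2,
    tp_rank=1, tp_world_size=2,
    tp_shard_dims=[0], tp_num_parts=[2],
    zero_numel=10, zero_orig_shape=[1, 10],
)
assert meta.used_tp and meta.used_zero
# dataclasses round-trip through from_dict
assert ParamDistMeta.from_dict(meta.__dict__) == meta
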
@@ -1,131 +0,0 @@
import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generator, List, Optional, Tuple

import torch

from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
from .meta import ParamDistMeta
from .utils import is_duplicated_list


class CheckpointReader(ABC):

    def __init__(self, base_name: str) -> None:
        super().__init__()
        self.base_name = base_name
        self.meta_list = []

    @abstractmethod
    def read(self, name: str) -> dict:
        pass

    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        pass

    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        pass

    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        pass

    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        pass

    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        pass

    @abstractmethod
    def load_others(self) -> dict:
        pass


class DiskCheckpointReader(CheckpointReader):

    def __init__(self, base_name: str) -> None:
        super().__init__(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        global_meta = self.read(GLOBAL_META_FILE_NAME)
        for meta_file_name in global_meta['meta']:
            meta = self.read(meta_file_name)
            if meta.get('dist_meta', None) is None:
                # only a global checkpoint can have empty dist_meta
                assert len(global_meta['meta']) == 1
            self.meta_list.append(meta)

    def read(self, name: str) -> dict:
        return torch.load(os.path.join(self.base_name, name))

    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', None),
                       meta.get('paired_os', None)) for meta in self.meta_list]
        dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
        # reduce param_count
        param_count = Counter(p for params in params_list for p in params)
        # validate param_to_os
        assert is_duplicated_list(param_to_os_list)
        assert is_duplicated_list(paired_os_list)
        return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]

    def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
        meta = self.meta_list[rank]
        checkpoint_names = meta.get(shard_type, [])
        for name in checkpoint_names:
            yield self.read(name)

    def load_model(self, rank: int) -> Generator[dict, None, None]:
        return self._load_shard('model', rank)

    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        indices = [0] * len(self.meta_list)
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                model_checkpoint_names = meta.get('model', [])
                if indices[i] < len(model_checkpoint_names):
                    shards[i] = self.read(model_checkpoint_names[indices[i]])
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        param_groups = None
        for shard in self._load_shard('optimizer', rank):
            if param_groups is None:
                param_groups = shard['param_groups']
            else:
                shard['param_groups'] = param_groups
            yield shard

    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        indices = [0] * len(self.meta_list)
        param_groups = []
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                optimizer_checkpoint_names = meta.get('optimizer', [])
                if indices[i] < len(optimizer_checkpoint_names):
                    shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
                    if indices[i] == 0:
                        param_groups.append(shards[i]['param_groups'])
                    else:
                        shards[i]['param_groups'] = param_groups[i]
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_others(self) -> dict:
        return self.read(OTHER_CKPT_FILE_NAME)

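A sketch of inspecting a checkpoint directory with the disk reader (assuming `./ckpt` was produced by the `save` API shown earlier, and the usual caveat that the module path is an assumption):

from colossalai.utils.checkpoint_io.reader import DiskCheckpointReader

reader = DiskCheckpointReader('./ckpt')
dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
print(param_count)                      # Counter: param name -> number of owning ranks
for shard in reader.load_model(rank=0):
    print(sorted(shard.keys()))         # keys of each model shard, in save order
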
@@ -1,223 +0,0 @@
import warnings
from copy import deepcopy
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple

from torch import Tensor
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from .meta import ParamDistMeta


def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
    if arg is not None:
        return fn(arg)


def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
    # ensure all params in the optimizer are in the model state dict
    params_set = set(id(p) for p in model.parameters())
    for group in optimizer.param_groups:
        for p in group['params']:
            assert id(p) in params_set
    param_mappings = {}
    start_index = 0

    def get_group_mapping(group):
        nonlocal start_index
        param_mappings.update(
            {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
        start_index += len(group['params'])

    for g in optimizer.param_groups:
        get_group_mapping(g)
    return {k: param_mappings[id(p)] for k, p in model.named_parameters()}


def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
    size = 0
    for v in state.values():
        if isinstance(v, Tensor):
            size += v.numel() * v.element_size()
    return size


class ModelCheckpointSharder:

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, Tensor] = {}
        self.buffer_size: int = 0

    def append(self, key: str, tensor: Tensor) -> Optional[dict]:
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {}
            self.buffer_size = 0
        self.buffer[key] = tensor
        self.buffer_size += tensor.numel() * tensor.element_size()
        return retval

    def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
        shards = []
        for key, tensor in state_dict.items():
            shard = self.append(key, tensor)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        return self.buffer if len(self.buffer) > 0 else None


class OptimizerCheckpointSharder:

    def __init__(self, max_shard_size: int, param_groups: dict) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
        self.buffer_size: int = 0
        self.returned_first: bool = False

    def append(self, key: int, state: dict) -> Optional[dict]:
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {'state': {}}
            self.buffer_size = 0
        self.buffer['state'][key] = state
        self.buffer_size += compute_optimizer_state_size(state)
        return retval

    def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
        shards = []
        for key, state in state_dict['state'].items():
            shard = self.append(key, state)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        return self.buffer if len(self.buffer['state']) > 0 else None


def shard_checkpoint(max_shard_size: int,
                     model_state_dict: Dict[str, Tensor],
                     optimizer_state_dict: Optional[dict] = None,
                     param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
    has_optimizer: bool = False
    if optimizer_state_dict is not None:
        assert param_to_os is not None
        os_to_param = {v: k for k, v in param_to_os.items()}
        for os_key in optimizer_state_dict['state'].keys():
            assert os_key in os_to_param
            assert os_to_param[os_key] in model_state_dict
        has_optimizer = True
    model_sharder = ModelCheckpointSharder(max_shard_size)
    model_shards = model_sharder.extend(model_state_dict)
    run_if_not_none(model_shards.append, model_sharder.complete())
    if not has_optimizer:
        return model_shards, []
    optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
    optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
    run_if_not_none(optimizer_shards.append, optimizer_sharder.complete())
    return model_shards, optimizer_shards


def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict,
                  param_to_os: Dict[str, int]) -> dict:
    os_to_param = {v: k for k, v in param_to_os.items()}
    paired_os = {}
    for idx, state in optimizer_state_dict['state'].items():
        paired_os[idx] = {}
        p = model_state_dict[os_to_param[idx]]
        for k, v in state.items():
            if isinstance(v, Tensor) and v.shape == p.shape:
                paired_os[idx][k] = True
            else:
                paired_os[idx][k] = False
    return paired_os


def build_checkpoints(max_size: int,
                      model: Module,
                      optimizer: Optional[Optimizer] = None,
                      param_to_os: Optional[Dict[str, int]] = None,
                      dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
                      eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
    save_global = dist_meta is None
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict() if optimizer else None
    meta = {'dist_meta': dist_meta}
    if optimizer:
        param_to_os = param_to_os or get_param_to_os(model, optimizer)
        paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
        meta['param_to_os'] = param_to_os
        meta['paired_os'] = paired_os
    if not save_global and eliminate_replica:
        # filter dp-replicated params
        model_state_dict = {
            k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
        }
        if optimizer:
            optimizer_state_dict['state'] = {
                param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]]
                for k in model_state_dict.keys()
                if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
            }
    meta['params'] = list(model_state_dict.keys())
    if len(model_state_dict) == 0:
        warnings.warn('model state dict is empty, checkpoint is not saved')
        return [], [], meta
    model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
                                                                param_to_os)
    return model_checkpoints, optimizer_checkpoints, meta


def is_duplicated_list(list_: List[Any]) -> bool:
    if len(list_) == 0:
        return True
    elem = list_[0]
    for x in list_[1:]:
        if x != elem:
            return False
    return True


def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
    for k, v in src_state.items():
        if k in dest_state:
            old_v = dest_state[k]
            if isinstance(old_v, Tensor):
                old_v.copy_(v)
            else:
                dest_state[k] = v


def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
    assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
    state_dict = deepcopy(state_dict)
    groups = optimizer.param_groups
    saved_groups = state_dict['param_groups']
    # map saved state indices to the live parameters, in param-group order
    idx_to_p: Dict[int, Parameter] = {
        old_id: p for old_id, p in zip(chain.from_iterable(g['params'] for g in saved_groups),
                                       chain.from_iterable(g['params'] for g in groups))
    }
    missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
    unexpected_keys = []
    error_msgs = []
    for idx, state in state_dict['state'].items():
        if idx in idx_to_p:
            old_state = optimizer.state[idx_to_p[idx]]
            copy_optimizer_state(state, old_state)
        else:
            unexpected_keys.append(idx)
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
                                                                                 '\n\t'.join(error_msgs)))

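The sharders above flush lazily: `append` returns a full buffer only when the size threshold was already reached before the new entry, and `complete` drains whatever remains. A small self-contained sketch (the module path is an assumption):

import torch
from colossalai.utils.checkpoint_io.utils import ModelCheckpointSharder

sharder = ModelCheckpointSharder(max_shard_size=1024)
state_dict = {f'w{i}': torch.zeros(64) for i in range(8)}    # 8 tensors, 256 bytes each
shards = sharder.extend(state_dict)                          # full shards flushed mid-stream
last = sharder.complete()                                    # the leftover buffer, if any
if last is not None:
    shards.append(last)
assert sum(len(shard) for shard in shards) == len(state_dict)
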
@@ -1,98 +0,0 @@
import os
from abc import ABC, abstractmethod
from typing import Optional

import torch

from .constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME,
                       OTHER_CKPT_FILE_NAME)


class CheckpointWriter(ABC):

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__()
        self.base_name = base_name
        self.overwrite = overwrite
        self.rank = rank
        self.world_size = world_size
        self.is_distributed = world_size > 1
        self.is_main_process = rank == 0

    @abstractmethod
    def write(self, name: str, state_dict: dict) -> None:
        pass

    @abstractmethod
    def save_model(self, model_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_meta(self, meta_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_others(self, kwargs: dict) -> None:
        pass


class DiskCheckpointWriter(CheckpointWriter):

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__(base_name, overwrite, rank, world_size)
        if not os.path.exists(base_name):
            os.makedirs(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        self.model_checkpoint_names = []
        self.optimizer_checkpoint_names = []
        self.is_meta_saved: bool = False
        self._save_global_meta()

    def write(self, name: str, state_dict: dict) -> None:
        path = os.path.join(self.base_name, name)
        if os.path.exists(path) and not self.overwrite:
            raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
        torch.save(state_dict, path)

    def _save_global_meta(self) -> None:
        if self.is_main_process:
            global_meta = {'meta': []}
            if self.is_distributed:
                for i in range(self.world_size):
                    global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
            else:
                global_meta['meta'].append(META_CKPT_FILE_NAME)
            self.write(GLOBAL_META_FILE_NAME, global_meta)

    def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
        checkpoint_name = base_name
        if self.is_distributed:
            checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
        if shard_idx is not None:
            checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
        return checkpoint_name

    def save_model(self, model_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save model after saving meta'
        name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
        self.write(name, model_checkpoint)
        self.model_checkpoint_names.append(name)

    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
        name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
        self.write(name, optimizer_checkpoint)
        self.optimizer_checkpoint_names.append(name)

    def save_meta(self, meta_checkpoint: dict) -> None:
        if len(self.model_checkpoint_names) > 0:
            meta_checkpoint['model'] = self.model_checkpoint_names
        if len(self.optimizer_checkpoint_names) > 0:
            meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
        self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
        self.is_meta_saved = True

    def save_others(self, kwargs: dict) -> None:
        if self.is_main_process:
            self.write(OTHER_CKPT_FILE_NAME, kwargs)

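For reference, a non-main writer in a two-process run produces rank-suffixed files derived from the constants above; the directory and module path below are hypothetical, and the meta dict is only there to demonstrate the file naming:

import torch
from colossalai.utils.checkpoint_io.writer import DiskCheckpointWriter

writer = DiskCheckpointWriter('./ckpt', overwrite=True, rank=1, world_size=2)
writer.save_model({'fc.weight': torch.zeros(1)})
writer.save_meta({'dist_meta': None, 'params': ['fc.weight']})
# writes ./ckpt/model-rank1-shard0.bin and ./ckpt/meta-rank1.bin;
# global_meta.bin and other.bin are written only by the rank-0 writer
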
@@ -1,102 +0,0 @@
from typing import Optional

import pytest
import torch
import torch.nn as nn

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.common import print_rank_0

try:
    from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
except ImportError:
    pass
from lazy_init_utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed

from tests.kit.model_zoo import model_zoo


def find_shard_dim(shape: torch.Size) -> Optional[int]:
    for dim, size in enumerate(shape):
        if size % 2 == 0:
            return dim


def make_sharding_spec(original_tensor: torch.Tensor) -> Layout:
    shard_dim = find_shard_dim(original_tensor.shape)
    dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
    return target_sharding_spec


def _get_current_name(prefix: str, name: str) -> str:
    return f'{prefix}.{name}'.lstrip('.')


def generate_sharding_spec_dict(model: nn.Module) -> dict:
    sharding_spec_dict = {}

    @torch.no_grad()
    def generate_recursively(module: nn.Module, prefix: str = ''):
        # recursively visit the submodules
        for name, mod in module.named_children():
            generate_recursively(mod, prefix=_get_current_name(prefix, name))

        # generate specs for tensors directly attached to the current module
        for name, param in module.named_parameters(recurse=False):
            if isinstance(param, LazyTensor):
                sharding_spec = make_sharding_spec(param)
                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

        for name, buf in module.named_buffers(recurse=False):
            if isinstance(buf, LazyTensor):
                sharding_spec = make_sharding_spec(buf)
                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

    generate_recursively(model)

    return sharding_spec_dict


@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
def run_dist_lazy_init(subset, seed: int = 42):
    sub_model_zoo = model_zoo.get_sub_registry(subset)
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
    LazyTensor._pre_op_fn = lambda *args: set_seed(seed)

    for name, entry in sub_model_zoo.items():
        # TODO(ver217): lazy init does not support weight norm, skip these models
        if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
            continue
        print_rank_0(name)
        model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
        ctx = LazyInitContext(tensor_cls=_MyTensor)
        with ctx:
            model = model_fn()
        ctx = LazyInitContext()
        with ctx:
            deferred_model = model_fn()
        sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
        ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
        assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)


def run_dist(rank, world_size, port) -> None:
    colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port)
    run_dist_lazy_init()


@pytest.mark.skipif(not SUPPORT_LAZY, reason='torch version should be >= 1.12.0')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_lazy_init():
    spawn(run_dist, 4)


if __name__ == '__main__':
    test_dist_lazy_init()

@@ -1,120 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.utils import build_checkpoints


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def test_global_model():
    model = DummyModel()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model)
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 0
    assert meta['dist_meta'] is None
    orig_state_dict = model.state_dict()
    global_state_dict = model_checkpoints[0]
    assert set(orig_state_dict.keys()) == set(global_state_dict.keys())
    for k, v in orig_state_dict.items():
        assert torch.equal(v, global_state_dict[k])


def test_global_model_shard():
    model = DummyModel()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model)
    assert len(model_checkpoints) == 2
    assert len(optimizer_checkpoints) == 0
    assert meta['dist_meta'] is None
    orig_state_dict = model.state_dict()
    assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys())
    assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0
    for k, v in orig_state_dict.items():
        for state_dict in model_checkpoints:
            if k in state_dict:
                assert torch.equal(v, state_dict[k])


def test_global_optimizer():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
    assert len(optimizer_checkpoints) == 1
    assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
    for state in meta['paired_os'].values():
        for k, is_paired in state.items():
            if k == 'step':
                assert not is_paired
            else:
                assert is_paired
    orig_state_dict = optimizer.state_dict()
    state_dict = optimizer_checkpoints[0]
    for k, orig_state in orig_state_dict['state'].items():
        state = state_dict['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    assert orig_state_dict['param_groups'] == state_dict['param_groups']


def test_global_optimizer_shard():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer)
    assert len(optimizer_checkpoints) == 2
    assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1]
    orig_state_dict = optimizer.state_dict()
    assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set(
        optimizer_checkpoints[1]['state'].keys())
    assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0
    for k, orig_state in orig_state_dict['state'].items():
        state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][
            'state'] else optimizer_checkpoints[1]['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups']


def test_dist_model_optimizer():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 1
    assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0]
    assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state']
    dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 1


if __name__ == '__main__':
    test_global_model()
    test_global_model_shard()
    test_global_optimizer()
    test_global_optimizer_shard()
    test_dist_model_optimizer()

@ -1,186 +0,0 @@
|
|||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from torch.optim import Adam, Optimizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.checkpoint_io.io import load, save
|
||||
from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta
|
||||
|
||||
|
||||
def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
|
||||
assert set(a.keys()) == set(b.keys())
|
||||
for k, v in a.items():
|
||||
assert torch.equal(v, b[k])
|
||||
|
||||
|
||||
def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
|
||||
assert set(a['state'].keys()) == set(b['state'].keys())
|
||||
for k, state in a['state'].items():
|
||||
b_state = b['state'][k]
|
||||
for v1, v2 in zip(state.values(), b_state.values()):
|
||||
if isinstance(v1, Tensor):
|
||||
assert torch.equal(v1, v2)
|
||||
else:
|
||||
assert v1 == v2
|
||||
if not ignore_param_groups:
|
||||
assert a['param_groups'] == b['param_groups']
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(20, 1)
|
||||
|
||||
|
||||
def prepare_model_optim(shard: bool = False, zero: bool = False):
|
||||
model = DummyModel()
|
||||
if shard:
|
||||
model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
|
||||
if zero:
|
||||
dp_rank = dist.get_rank() // 2
|
||||
model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
|
||||
if dp_rank != 0:
|
||||
model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
|
||||
for p in model.parameters():
|
||||
p.grad = torch.rand_like(p)
|
||||
optimizer = Adam(model.parameters(), lr=1e-3)
|
||||
optimizer.step()
|
||||
return model, optimizer
|
||||
|
||||
|
||||
def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0):
|
||||
with torch.no_grad():
|
||||
for p in model.parameters():
|
||||
p.fill_(scalar)
|
||||
for state in optimizer.state.values():
|
||||
for v in state.values():
|
||||
if isinstance(v, Tensor):
|
||||
v.fill_(scalar)
|
||||
|
||||
|
||||
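# Layout assumed by the metas below (inferred from the surrounding tests, not
# from separate documentation): consecutive rank pairs form a tensor-parallel
# group of size 2 (tp_rank = rank % 2) and dp_world_size = nprocs // 2
# (dp_rank = rank // 2). fc.weight is sharded along dim 1 into two TP parts;
# with ZeRO, each flattened 10-element part is split [3, 7] across the two dp
# ranks, mirroring prepare_model_optim above.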
def get_dist_metas(nprocs: int, zero: bool = False):
    dp_world_size = nprocs // 2
    dist_metas = []
    for rank in range(nprocs):
        if zero:
            dist_metas.append({
                'fc.weight':
                    ParamDistMeta(rank // 2,
                                  dp_world_size,
                                  rank % 2,
                                  2,
                                  tp_shard_dims=[1],
                                  tp_num_parts=[2],
                                  zero_numel=10,
                                  zero_orig_shape=[1, 10]),
                'fc.bias':
                    ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
            })
        else:
            dist_metas.append({
                'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
                'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
            })
    return dist_metas


def get_redist_meta(nprocs: int):
    dp_world_size = nprocs // 2
    rank_meta = {
        'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
        'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
    }
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_world_size, 1)
    }
    return RedistMeta(rank_meta, [], param_meta)

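# 80 / 1024**3 GB is exactly 80 bytes: fc.weight holds 20 fp32 values
# (80 bytes), so weight and bias land in separate shards, while 0 appears to
# disable sharding altogether (an assumption based on how it is used here).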
@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0])
def test_save_global_load_global(max_shard_size_gb: float):
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb)
        new_model, new_optimizer = prepare_model_optim()
        load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb)
        check_model_state_dict(model.state_dict(), new_model.state_dict())
        check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict())


def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_fn()


def launch_dist(fn, world_size: int):
    spawn(run_dist, world_size, test_fn=fn)


def save_dist(dir_name: str, zero: bool):
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    reset_model_optim(model, optimizer)
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    save(dir_name, model, optimizer, dist_meta=get_dist_metas(world_size, zero)[rank])


def load_and_check_dist(dir_name: str):
    world_size = dist.get_world_size()
    model, optimizer = prepare_model_optim(shard=True)
    reset_model_optim(model, optimizer)
    model_state_dict = deepcopy(model.state_dict())
    optimizer_state_dict = deepcopy(optimizer.state_dict())
    reset_model_optim(model, optimizer, 1)
    load(dir_name, model, optimizer, get_redist_meta(world_size), get_dist_metas(world_size))
    check_model_state_dict(model_state_dict, model.state_dict())
    check_optim_state_dict(optimizer_state_dict, optimizer.state_dict())


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_global_load_dist():
    model, optimizer = prepare_model_optim()
    reset_model_optim(model, optimizer)
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 4)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist_load_dist():
    with TemporaryDirectory() as dir_name:
        # save tp + dp
        fn = partial(save_dist, dir_name, False)
        launch_dist(fn, 2)
        # load tp + dp
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 2)
    with TemporaryDirectory() as dir_name:
        # save tp + zero
        fn = partial(save_dist, dir_name, True)
        launch_dist(fn, 4)
        # load tp + dp
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 2)
        launch_dist(fn, 4)


if __name__ == '__main__':
    test_save_global_load_global(80 / 1024**3)
    test_save_global_load_global(0)
    test_save_global_load_dist()
    test_save_dist_load_dist()

@ -1,126 +0,0 @@
import os
from functools import partial
from tempfile import TemporaryDirectory

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import merge, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim(shard: bool = False, zero: bool = False):
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def test_merge_global():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 0
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 0


def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={'parallel': {
        'tensor': {
            'mode': '1d',
            'size': 2
        }
    }},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    test_fn()


def run_save_dist(dir_name: str, zero: bool):
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    dp_world_size = dist.get_world_size() // 2
    if not zero:
        dist_metas = {
            'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
            'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
        }
    else:
        dist_metas = {
            'fc.weight':
                ParamDistMeta(rank // 2,
                              dp_world_size,
                              rank % 2,
                              2,
                              tp_shard_dims=[1],
                              tp_num_parts=[2],
                              zero_numel=10,
                              zero_orig_shape=[1, 10]),
            'fc.bias':
                ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
        }
    save(dir_name, model, optimizer, dist_meta=dist_metas)

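# After merging the 4-rank distributed checkpoint back into a global one, the
# output directory should hold 5 files (model, optimizer, meta, global meta
# and "other" checkpoints, judging from test_save_global in the save tests),
# with fc.weight gathered back to its full 20 columns.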
@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_merge_tp_dp(zero: bool):
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name, zero)
        world_size = 4
        spawn(run_dist, world_size, test_fn=fn)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 5
            global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME))
            assert len(global_meta['meta']) == 1
            meta = torch.load(os.path.join(output_dir, global_meta['meta'][0]))
            assert meta['dist_meta'] is None
            assert len(meta['params']) == 2
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
            model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0]))
            assert len(model_state_dict) == 2
            assert model_state_dict['fc.weight'].size(1) == 20
            optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0]))
            assert len(optimizer_state_dict['state']) == 2
            assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
            assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20
            assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20


if __name__ == '__main__':
    test_merge_global()
    test_merge_tp_dp(False)
    test_merge_tp_dp(True)

@ -1,101 +0,0 @@
import torch

from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param


def test_unflatten_zero_param_even() -> None:
    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).chunk(4))
    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, unflattened_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_unflatten_zero_param_uneven() -> None:
    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)]
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, unflattened_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_1d_row() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_1d_col() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_2d() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)]
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_2d_reverse() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)]
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)

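# The hybrid case uses 12 shards = 6 TP parts x 2 ZeRO fragments, with the
# dp index varying fastest (ParamDistMeta receives dp_rank=i % 2 and
# tp_rank=i // 2); each 4-element TP part is flattened and split [1, 3]
# across the two dp ranks, matching the order of the `tensors` list below.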
def test_merge_param_hybrid() -> None:
    dist_metas = [
        ParamDistMeta(i % 2,
                      2,
                      i // 2,
                      6,
                      tp_shard_dims=[1, 0],
                      tp_num_parts=[3, 2],
                      zero_numel=4,
                      zero_orig_shape=[2, 2]) for i in range(12)
    ]
    orig_tensor = torch.rand(4, 6)
    tensors = [
        chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
        for chunk in t.contiguous().reshape(-1).split([1, 3])
    ]
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_merge_param_dummy() -> None:
    dist_metas = [ParamDistMeta(0, 1, 0, 1)]
    orig_tensor = torch.rand(4, 6)
    merged_tensor = merge_param([orig_tensor], dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


if __name__ == '__main__':
    test_unflatten_zero_param_even()
    test_unflatten_zero_param_uneven()
    test_gather_tp_param_1d_row()
    test_gather_tp_param_1d_col()
    test_gather_tp_param_2d()
    test_gather_tp_param_2d_reverse()
    test_merge_param_hybrid()
    test_merge_param_dummy()

@ -1,152 +0,0 @@
import os
from functools import partial
from tempfile import TemporaryDirectory

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import redist, save
from colossalai.utils.checkpoint_io.meta import (
    ParamDistMeta,
    ParamRedistMeta,
    PipelineRedistMeta,
    RankRedistMeta,
    RedistMeta,
)


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim(shard: bool = False, zero: bool = False):
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def get_dist_metas(nprocs: int, zero: bool = False):
    dp_world_size = nprocs // 2
    dist_metas = []
    for rank in range(nprocs):
        if zero:
            dist_metas.append({
                'fc.weight':
                    ParamDistMeta(rank // 2,
                                  dp_world_size,
                                  rank % 2,
                                  2,
                                  tp_shard_dims=[1],
                                  tp_num_parts=[2],
                                  zero_numel=10,
                                  zero_orig_shape=[1, 10]),
                'fc.bias':
                    ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
            })
        else:
            dist_metas.append({
                'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
                'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
            })
    return dist_metas


def get_redist_meta(nprocs: int):
    dp_world_size = nprocs // 2
    rank_meta = {
        'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
        'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
    }
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_world_size, 1)
    }
    return RedistMeta(rank_meta, [], param_meta)

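# After redistribution each rank should hold a (1, 10) column shard of
# fc.weight (tp size 2 over 20 columns); the size(1) == 10 checks below
# verify this for both the model weights and the Adam state tensors.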
def check_checkpoint_shape(dir_name: str):
    global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
    for meta_name in global_meta['meta']:
        meta = torch.load(os.path.join(dir_name, meta_name))
        assert meta['dist_meta'] is not None
        assert len(meta['params']) == 2
        assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
        model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
        assert len(model_state_dict) == 2
        assert model_state_dict['fc.weight'].size(1) == 10
        optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
        assert len(optimizer_state_dict['state']) == 2
        assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
        assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10
        assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10


def test_global_to_dist():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        with TemporaryDirectory() as output_dir:
            redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
            check_checkpoint_shape(output_dir)


def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={'parallel': {
        'tensor': {
            'mode': '1d',
            'size': 2
        }
    }},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    test_fn()


def run_save_dist(dir_name: str, zero: bool):
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    save(dir_name, model, optimizer, dist_meta=get_dist_metas(4, zero)[rank])


@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_dist_to_dist(zero: bool):
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name, zero)
        world_size = 4
        spawn(run_dist, world_size, test_fn=fn)
        with TemporaryDirectory() as output_dir:
            redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
            if not zero:
                assert len(os.listdir(output_dir)) == 0
            else:
                check_checkpoint_shape(output_dir)


if __name__ == '__main__':
    test_global_to_dist()
    test_dist_to_dist(False)
    test_dist_to_dist(True)

@ -1,149 +0,0 @@
import os
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Adam

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import (
    GLOBAL_META_FILE_NAME,
    META_CKPT_FILE_NAME,
    MODEL_CKPT_FILE_NAME,
    OTHER_CKPT_FILE_NAME,
)
from colossalai.utils.checkpoint_io.io import save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta


def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    assert set(a.keys()) == set(b.keys())
    for k, v in a.items():
        assert torch.equal(v, b[k])


def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
    assert set(a['state'].keys()) == set(b['state'].keys())
    for k, state in a['state'].items():
        b_state = b['state'][k]
        for v1, v2 in zip(state.values(), b_state.values()):
            if isinstance(v1, Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    if not ignore_param_groups:
        assert a['param_groups'] == b['param_groups']


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def test_overwrite():
    model = DummyModel()
    with TemporaryDirectory() as dir_name:
        with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f:
            pass
        with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'):
            save(dir_name, model)


def test_save_global():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        assert len(os.listdir(dir_name)) == 5
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME
        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
        assert len(meta['model']) == 1
        assert len(meta['optimizer']) == 1
        model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
        check_model_state_dict(model.state_dict(), model_state_dict)
        optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
        check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict)
        other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME))
        assert len(other_state_dict) == 0

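# fc.weight alone is 20 fp32 values = 80 bytes, so a max shard size of
# 80 / 1024**3 GB (exactly 80 bytes) forces weight and bias into separate
# shards: 2 model files + 2 optimizer files plus the meta, global meta and
# "other" checkpoints give the 7 files asserted below.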
def test_save_global_shard():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
        assert len(os.listdir(dir_name)) == 7
        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
        assert len(meta['model']) == 2 and len(meta['optimizer']) == 2
        model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']]
        assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0
        check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]})
        optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']]
        assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0
        assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1]
        check_optim_state_dict(
            optimizer.state_dict(), {
                'state': {
                    **optimizer_state_dicts[0]['state'],
                    **optimizer_state_dicts[1]['state']
                },
                'param_groups': optimizer_state_dicts[0]['param_groups']
            })


def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_fn()


def run_save_dist(dir_name):
    model, optimizer = prepare_model_optim()
    dist_metas = {
        'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1),
        'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1)
    }
    save(dir_name, model, optimizer, dist_meta=dist_metas)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist():
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name)
        world_size = 2
        spawn(run_dist, world_size, test_fn=fn)
        assert len(os.listdir(dir_name)) == 8
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 2
        for rank, meta_name in enumerate(global_meta['meta']):
            meta = torch.load(os.path.join(dir_name, meta_name))
            assert meta.get('dist_meta', None) is not None
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
            model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
            assert len(model_state_dict) == 2
            optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
            assert len(optimizer_state_dict['state']) == 2
            assert 'param_groups' in optimizer_state_dict


if __name__ == '__main__':
    test_overwrite()
    test_save_global()
    test_save_global_shard()
    test_save_dist()

@ -1,137 +0,0 @@
import torch

from colossalai.utils.checkpoint_io.meta import ParamRedistMeta
from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param


def test_flatten_zero_param_even() -> None:
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12])
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).chunk(4))
    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    assert len(tensors) == len(flat_tensors)
    for t, st in zip(tensors, flat_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    unmerged_tensors = unmerged_tensors[0]
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert torch.equal(t, tl)


def test_flatten_zero_param_uneven() -> None:
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13])
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0
    flat_tensors = flat_tensors[1:-1]
    assert len(tensors) == len(flat_tensors)
    for t, st in zip(tensors, flat_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    unmerged_tensors = unmerged_tensors[0]
    assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0
    unmerged_tensors = unmerged_tensors[1:-1]
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert torch.equal(t, tl)


def test_split_tp_param_1d_row() -> None:
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_split_tp_param_1d_col() -> None:
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_split_tp_param_2d() -> None:
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3])
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_split_tp_param_2d_reverse() -> None:
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2])
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])

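# ParamRedistMeta(2, 6, ...) targets 6 TP ranks x 2 ZeRO ranks: the tensor is
# first split into 6 TP parts (dims [1, 0], parts [3, 2]), then each flattened
# 4-element part is cut at zero_offsets [0, 1] into fragments of 1 and 3
# elements, so unmerge_param returns a nested list indexed [tp_rank][dp_rank].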
def test_unmerge_param_hybrid() -> None:
    redist_meta = ParamRedistMeta(2,
                                  6,
                                  tp_shard_dims=[1, 0],
                                  tp_num_parts=[3, 2],
                                  zero_start_dp_rank=0,
                                  zero_offsets=[0, 1])
    orig_tensor = torch.rand(4, 6)
    tensors = [
        chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
        for chunk in t.contiguous().reshape(-1).split([1, 3])
    ]
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2
    for tp_rank in range(6):
        for dp_rank in range(2):
            assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank])


def test_unmerge_param_dummy() -> None:
    redist_meta = ParamRedistMeta(1, 1)
    orig_tensor = torch.rand(4, 6)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1
    assert torch.equal(orig_tensor, unmerged_tensors[0][0])


if __name__ == '__main__':
    test_flatten_zero_param_even()
    test_flatten_zero_param_uneven()
    test_split_tp_param_1d_row()
    test_split_tp_param_1d_col()
    test_split_tp_param_2d()
    test_split_tp_param_2d_reverse()
    test_unmerge_param_hybrid()
    test_unmerge_param_dummy()

@ -1,140 +0,0 @@
from functools import partial

import torch
import torch.distributed as dist

from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint
from colossalai.zero.legacy.shard_utils import TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2

LOGGER = get_dist_logger('zero_test')

MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None)))

_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                          fp32_reduce_scatter=False,
                          tensor_placement_policy='cuda',
                          gradient_predivide_factor=1.0,
                          shard_strategy=TensorShardStrategy(),
                          reuse_fp16_shard=False)

_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5,
                              min_scale=1,
                              growth_factor=2,
                              backoff_factor=0.5,
                              growth_interval=1000,
                              hysteresis=2,
                              max_scale=2**32)

ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                            zero=dict(
                                model_config=_ZERO_MODEL_CONFIG,
                                optimizer_config=_ZERO_OPTIMIZER_CONFIG,
                            ),
                            parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))

CONFIG = dict(fp16=dict(mode=None,),
              zero=dict(level=3,
                        verbose=False,
                        offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False),
                        offload_param_config=dict(device='cpu',
                                                  pin_memory=True,
                                                  buffer_count=5,
                                                  buffer_size=1e8,
                                                  max_in_cpu=1e9)),
              parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))


def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
    return module


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


def check_grads(model, zero_model, loose=False):
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        grad = p.grad.float()
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params(model, zero_model, loose=False):
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_p = zero_p.clone().to(p.device)
        # assert p.dtype == zero_p.dtype
        assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}"

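# The "padding" checks below account for ZeRO's flattened layout: a replicated
# parameter's gradient is flattened and chunked across ranks, and the last
# chunk may be padded, so the sharded tensor is truncated to the local chunk's
# length before comparison.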
def check_grads_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
        # zero_grad = zero_p.grad.clone().to(p.device)
        if zero_p.colo_attr.is_replicated:
            zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device)
            chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
            if rank >= len(chunks):
                continue
            grad = chunks[rank].float()
            if zero_grad.size(0) > grad.size(0):
                zero_grad = zero_grad[:grad.size(0)]
        else:
            zero_grad = zero_p.colo_attr.grad_payload
            grad = p.grad.to(zero_grad.dtype)

        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'


def check_params_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_p.size(0) > p.size(0):
            zero_p = zero_p[:p.size(0)]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
    rank = dist.get_rank()
    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
        if zero_p.colo_attr.param_is_sharded:
            zero_p = zero_p.colo_attr.data_payload.to(p.device).float()
            chunks = torch.flatten(p).chunk(dist.get_world_size())
            if rank >= len(chunks):
                continue
            p = chunks[rank].float()
            if zero_p.size(0) > p.size(0):
                zero_p = zero_p[:p.size(0)]
        else:
            zero_p = zero_p.colo_attr.data_payload.to(p.device)

        assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype)
        assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'

@ -1,67 +0,0 @@
import pytest
import torch
from common import CONFIG
from test_sharded_optim_v2 import _run_step

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
from colossalai.zero.low_level._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs


@parameterize("cpu_offload", [True, False])
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
    test_models = ['repeated_computed_layers']
    shard_strategy = shard_strategy_class()

    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

        with ZeroInitContext(target_device=torch.device('cpu:0') if cpu_offload else get_current_device(),
                             shard_strategy=shard_strategy,
                             shard_param=True):
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(
            zero_model,
            shard_strategy,
            tensor_placement_policy='cpu' if cpu_offload else 'cuda',
            reuse_fp16_shard=True,
        )

        sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3)
        sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio)

        for i, (data, label) in enumerate(train_dataloader):
            if i > 1:
                break
            assert zero_model.overflow_counter == 0
            data, label = data.cuda(), label.cuda()
            _run_step(zero_model, sharded_optim, data, label, criterion, False)
            for param in zero_model.parameters():
                assert not has_inf_or_nan(param.colo_attr.data_payload)


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_test_found_inf()


# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_if_address_is_in_use()
def test_found_inf(world_size):
    spawn(_run_dist, world_size)


if __name__ == '__main__':
    test_found_inf(world_size=2)

@ -1,75 +0,0 @@
import pytest
import torch

from colossalai.testing import clear_cache_before_run
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState


@pytest.mark.dist
@clear_cache_before_run()
def test_gemini_manager():
    # reset the manager, in case any memory information is left over
    manager = StatefulTensor.GST_MGR
    manager.reset()

    # occupation 8
    st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
    # occupation 60
    st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))

    # occupation 28
    t1 = torch.empty(7, device='cuda')
    # occupation 12
    t2 = torch.empty(3, device='cpu')
    st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
    st4 = StatefulTensor(None, TensorState.FREE)

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 60
    assert manager.total_mem['cuda'] == 36
    assert manager.state_mem['cpu'][TensorState.HOLD] == 60
    assert manager.state_mem['cuda'][TensorState.HOLD] == 8
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28

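    # payload_reset swaps in a new payload and re-registers its memory under
    # the tensor's current state: st3 and st4 both now hold the 12-byte CPU
    # tensor t2, so CPU usage grows to 60 + 12 + 12 = 84 bytes and only st1's
    # 8 bytes remain on CUDA.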
    st4.payload_reset(t2)
    st3.payload_reset(t2)

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 84
    assert manager.total_mem['cuda'] == 8
    assert manager.state_mem['cpu'][TensorState.HOLD] == 72
    assert manager.state_mem['cuda'][TensorState.HOLD] == 8
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0

    st1.move_to(torch.device('cpu'))
    st2.move_to(torch.device('cpu'))
    st3.move_to(torch.device('cuda', 0))

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 80
    assert manager.total_mem['cuda'] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD] == 80
    assert manager.state_mem['cuda'][TensorState.HOLD] == 0
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12

    st1.trans_state(TensorState.COMPUTE)
    st2.trans_state(TensorState.COMPUTE)
    st2.trans_state(TensorState.HOLD_AFTER_BWD)

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 80
    assert manager.total_mem['cuda'] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD] == 12
    assert manager.state_mem['cuda'][TensorState.HOLD] == 0
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
    assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
    assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0


if __name__ == '__main__':
    test_gemini_manager()

@ -1,73 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
from common import CONFIG

import colossalai
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_device_memory_used
from colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from tests.components_to_test.registry import non_distributed_component_funcs


@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(init_device_type, shard_strategy_class):
    logger = get_dist_logger("test_zero_init")

    for name, get_components_func in non_distributed_component_funcs._registry.items():
        # ZeroInitContext automatically casts parameters to fp16, and the beit
        # model uses tensor.erfinv_() to initialize weights; erfinv_() doesn't
        # support Half on CPU, so we skip the beit model
        if name == 'beit':
            continue
        model_builder, _, _, _, _ = get_components_func()
        if init_device_type == 'cuda':
            init_device = get_current_device()
        elif init_device_type == 'cpu':
            init_device = torch.device("cpu")
        else:
            continue

        model_numel_tensor = torch.zeros(1, dtype=torch.int)
        with ZeroInitContext(target_device=init_device,
                             shard_strategy=shard_strategy_class(),
                             shard_param=True,
                             model_numel_tensor=model_numel_tensor):
            model = model_builder(checkpoint=True)

        for param in model.parameters():
            assert hasattr(param, 'colo_attr')
            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
            assert param.colo_attr.sharded_data_tensor.is_sharded
            assert param.colo_attr.data_payload.device.type == init_device.type, \
                f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'

        cuda_mem_use, _ = colo_model_mem_usage(model)
        model_data_cuda_mem_MB = cuda_mem_use / 1e6
        logger.info(f"Exiting ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
        sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6
        logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
        logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_model_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_if_address_is_in_use()
def test_zero_init_context(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_init_context(1)

@ -1,82 +0,0 @@
import copy
from typing import Optional

import torch

from colossalai.testing import clear_cache_before_run
from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
from tests.components_to_test.registry import non_distributed_component_funcs


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)

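# BaseParamHookMgr registers a backward hook on every parameter in the list;
# the wrapper below counts how many times the hook fires so the test can
# check that each parameter triggers it exactly once and that hooked and
# unhooked runs produce identical gradients.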
def run_model(model, inputs, label, criterion, use_param_hook=False):
    if use_param_hook:

        class HooKWrapper:

            def __init__(self) -> None:
                self.hook_triggered_times = 0

            def wrapper_func(self):

                def hook(param, grad) -> Optional[torch.Tensor]:
                    self.hook_triggered_times += 1
                    return grad

                return hook

        hookwrapper = HooKWrapper()
        param_list = [p for p in model.parameters()]
        hook_mgr = BaseParamHookMgr(param_list)
        hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())

    model.zero_grad(set_to_none=True)

    with torch.cuda.amp.autocast():
        if criterion:
            y = model(inputs)
            loss = criterion(y, label)
        else:
            loss = model(inputs, label)
        loss = loss.float()
    loss.backward()

    if use_param_hook:
        hook_mgr.remove_hooks()
        return hookwrapper.hook_triggered_times


@clear_cache_before_run()
def test_base_param_hook():
    test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model']
    # test_models = ['bert']

    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, _, criterion = get_components_func()

        torch.manual_seed(0)
        model = model_builder(checkpoint=True).cuda()
        model.train()

        for i, (inputs, label) in enumerate(train_dataloader):
            if i > 0:
                break
            model_copy = copy.deepcopy(model)

            run_model(model, inputs.cuda(), label.cuda(), criterion, False)
            ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True)

            # make sure the param hook has only been fired once, even with parameter sharing
            assert ret2 == len(list(model.parameters()))

            for p, p_copy in zip(model.parameters(), model_copy.parameters()):
                assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"


if __name__ == '__main__':
    test_base_param_hook()

@ -1,64 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
from common import CONFIG, check_grads_padding, run_fwd_bwd
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs


@parameterize("enable_autocast", [True])
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy_class):
    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model']
    shard_strategy = shard_strategy_class()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, _, criterion = get_components_func()

        with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                             shard_strategy=shard_strategy,
                             shard_param=True):
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(zero_model, shard_strategy)

        model = model_builder(checkpoint=True).half()
        col_model_deepcopy(zero_model, model)
        model = model.cuda()

        model = DDP(model, device_ids=[torch.cuda.current_device()])

        for i, (data, label) in enumerate(train_dataloader):
            if i > 5:
                break

            data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
            run_fwd_bwd(model, data, label, criterion, enable_autocast)
            run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)

            check_grads_padding(model, zero_model, loose=True)


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_model_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_if_address_is_in_use()
def test_shard_model_v2(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_shard_model_v2(world_size=2)

@ -1,91 +0,0 @@
from copy import deepcopy

import pytest
import torch
from common import CONFIG, allclose

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from colossalai.zero.legacy.sharded_param import ShardedTensor
from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2


@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_shard_tensor_with_strategy(shard_strategy_class, world_size):
    t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
    assert list(t.origin_shape) == [world_size * 2, 3]
    assert list(t.shape) == [world_size * 2, 3]

    shard_strategy = shard_strategy_class()

    # test shard strategy
    shard_strategy.shard([t])
    assert list(t.shape) == [6], f"{list(t.shape)} vs 6"
    shard_strategy.gather([t])
    assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}"


def _run_shard_tensor(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_shard_tensor_with_strategy(world_size=world_size)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_if_address_is_in_use()
def test_shard_tensor(world_size):
    spawn(_run_shard_tensor, world_size)


def _run_shard_param_v2(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    param = torch.nn.Parameter(torch.randn(2, 3))
    param_ref = deepcopy(param)
    sparam = ShardedParamV2(param=param)

    allclose(sparam.data_payload, param_ref.data)

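    # A 2x3 fp32 tensor occupies 2 * 3 * 4 = 24 bytes; with both the data
    # payload and the saved gradient resident on CPU, the expected usage is
    # 2 * 3 * 4 * 2 = 48 bytes, which is what the asserts below track.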
    # Test get memory usage
    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"

    sparam.set_data_none()
    assert param.data.numel() == 0
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    # 4 is size of dummy tensor of param.data
    assert cpu_mem_use == 2 * 3 * 4 * 2

    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
    sparam.set_data_none()
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2
    assert cuda_mem_use == 0

    # append a grad to torch param
    param.data = sparam.data_payload
    param.grad = torch.randn(2, 3)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
    assert cuda_mem_use == 0

    # reuse torch grad for sparam
    sparam.saved_grad = StatefulTensor(param.grad)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2
    assert cuda_mem_use == 0


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_if_address_is_in_use()
def test_shard_param_v2(world_size):
    spawn(_run_shard_param_v2, world_size)


if __name__ == '__main__':
    # test_shard_tensor(2)
    test_shard_param_v2(2)
|
|
@ -1,89 +0,0 @@
import pytest
import torch

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed


def init_zero(model_builder, placement_policy):
    device = get_current_device() if placement_policy == 'cuda' else torch.device('cpu')
    shard_strategy = TensorShardStrategy()
    with ZeroInitContext(target_device=device, shard_strategy=shard_strategy, shard_param=True):
        model = model_builder()
    model = ShardedModelV2(
        model,
        shard_strategy,
        tensor_placement_policy=placement_policy,
        reuse_fp16_shard=True,
    )
    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ShardedOptimizerV2(model, optim, initial_scale=32)
    return model, optim


def run_step(model, optim, criterion, data, label):
    optim.zero_grad()
    logits = model(data)
    loss = criterion(logits, label)
    optim.backward(loss)
    optim.step()


def check_state_dict_eq(state_dict, other):
    for p, state in state_dict['state'].items():
        other_state = other['state'][p]
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}'
            else:
                assert v == other_state[k]
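
# For reference, a PyTorch optimizer state dict has the rough shape
# {'state': {param_id: {'step': ..., 'exp_avg': Tensor, ...}}, 'param_groups': [...]};
# check_state_dict_eq above only walks the per-parameter 'state' entries, which
# is sufficient here since both optimizers are built identically by init_zero.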


@parameterize('placement_policy', ['cuda', 'cpu'])
def run_nested_model(placement_policy):
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    set_seed(42)
    model, optim = init_zero(model_builder, placement_policy)
    set_seed(42)
    model_copy, optim_copy = init_zero(model_builder, placement_policy)

    model.train()
    model_copy.train()
    pg = ProcessGroup()
    set_seed(pg.dp_local_rank())
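    # presumably reseeding per data-parallel rank so that subsequent
    # randomness (e.g. the batches drawn below) differs across ranks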
    data_iter = iter(train_dataloader)

    data, label = map(lambda x: x.cuda(), next(data_iter))
    run_step(model, optim, criterion, data, label)
    optim_copy.load_state_dict(optim.state_dict())
    check_state_dict_eq(optim.state_dict(), optim_copy.state_dict())

    data, label = map(lambda x: x.cuda(), next(data_iter))
    run_step(model_copy, optim_copy, criterion, data, label)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_nested_model()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_sharded_optim_state_dist(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_sharded_optim_state_dist(2)
@ -1,110 +0,0 @@
import pytest
import torch
import torch.distributed as dist
from common import CONFIG, check_sharded_model_params
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
from colossalai.zero.low_level._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs


def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)

    loss = loss.float()
    if isinstance(model, ShardedModelV2):
        optimizer.backward(loss)
    else:
        loss.backward()
    optimizer.step()


@parameterize("cpu_offload", [True, False])
@parameterize("use_cpuadam", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model']
    shard_strategy = shard_strategy_class()

    if use_cpuadam and cpu_offload is False:
        return
    if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
        return
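    # With the two early returns above, the combinations actually exercised
    # (each with both shard strategies) are:
    #   cpu_offload=False, use_cpuadam=False, gpu_margin_mem_ratio=0.0
    #   cpu_offload=True,  use_cpuadam=False, gpu_margin_mem_ratio=0.0
    #   cpu_offload=True,  use_cpuadam=True,  gpu_margin_mem_ratio in {0.0, 0.7}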

    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

        with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                             shard_strategy=shard_strategy,
                             shard_param=True):
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(
            zero_model,
            shard_strategy,
            tensor_placement_policy='cpu' if cpu_offload else 'auto',
            reuse_fp16_shard=use_cpuadam,
        )

        model = model_builder(checkpoint=True).half()
        col_model_deepcopy(zero_model, model)
        model = model.cuda().float()

        if use_cpuadam:
            optimizer_class = CPUAdam
        optim = optimizer_class(model.parameters(), lr=1e-3)
        sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
        sharded_optim = ShardedOptimizerV2(zero_model,
                                           sharded_optim,
                                           initial_scale=2**5,
                                           gpu_margin_mem_ratio=gpu_margin_mem_ratio)

        amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
        apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
        if dist.get_world_size() > 1:
            apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()])

        for i, (data, label) in enumerate(train_dataloader):
            if i > 5:
                break
            data, label = data.cuda(), label.cuda()
            _run_step(apex_model, apex_optimizer, data, label, criterion, False)
            _run_step(zero_model, sharded_optim, data, label, criterion, False)
            check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
            for param in model.parameters():
                assert not has_inf_or_nan(param)


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_test_sharded_optim_v2()


# use_cpuadam=True could in principle be combined with cpu_offload=False,
# but that combination is skipped inside _run_test_sharded_optim_v2
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_if_address_is_in_use()
def test_sharded_optim_v2(world_size):
    spawn(_run_dist, world_size)


if __name__ == '__main__':
    test_sharded_optim_v2(world_size=2)
@ -1,87 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
import torch.distributed as dist
from torchvision.models import resnet50

import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import TensorShardStrategy


def run_dist(rank, world_size, port):
    # this test runs on a torchvision resnet50, whose batch normalization
    # running stats must stay synchronized across ranks;
    # cudnn is configured to be deterministic so that the randomness
    # of convolution layers is disabled
    zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy()))
    colossalai.launch(config=dict(zero=zero_config, cudnn_deterministic=True, cudnn_benchmark=False),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    with ZeroInitContext(target_device=torch.cuda.current_device(),
                         shard_strategy=gpc.config.zero.model_config.shard_strategy,
                         shard_param=True):
        model = resnet50()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    engine, *args = colossalai.initialize(model, optimizer, criterion)

    # train for a few dummy iterations
    engine.train()
    for _ in range(2):
        data = torch.rand(4, 3, 128, 128).cuda().half()
        label = torch.randint(0, 10, size=(4,)).cuda()
        engine.zero_grad()
        out = engine(data)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()

    # test
    # make sure the batch norm stats are synchronized
    # so that given the same input, the model produces the same
    # output on different ranks
    engine.eval()
    data = torch.rand(4, 3, 128, 128).cuda().half()
    dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))

    # predict
    out = engine(data)

    # test if results are equal
    tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
    tensor_list.insert(rank, out)
    dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))
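    # note: all_gather expects one destination tensor per rank; the list is
    # built with world_size - 1 empty placeholders plus this rank's own
    # output inserted at index `rank`, giving world_size entries in rank order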

    assert torch.all(tensor_list[0] == tensor_list[1]), \
        'expected the output from different ranks to be the same, but got different values'


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_sharded_optim_with_sync_bn():
    """
    This test makes sure that module buffers are synchronized between ranks
    when using ZeRO. An example of a module buffer is the running stats of a
    BatchNorm layer, i.e. its running mean and variance.

    If the buffers are not synchronized, the model will produce different
    outputs even though the inputs and parameters are the same, which is
    undesirable when running predictions.
    """
    spawn(run_dist, 2)


if __name__ == '__main__':
    test_sharded_optim_with_sync_bn()
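
# A minimal sketch of how the buffer synchronization described above could be
# checked directly; illustrative only, assuming `model`, `rank`, `world_size`
# and the data-parallel group come from a setup like run_dist:
#
#     bn = next(m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d))
#     stats = bn.running_mean.detach().clone()
#     gathered = [torch.empty_like(stats) for _ in range(world_size)]
#     dist.all_gather(gathered, stats, group=gpc.get_group(ParallelMode.DATA))
#     assert all(torch.equal(gathered[0], g) for g in gathered)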
@ -1,55 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
from common import CONFIG

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs


@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_zero_state_dict(shard_strategy_class):
    test_models = ['repeated_computed_layers', 'resnet18']
    shard_strategy = shard_strategy_class()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()

        with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                             shard_strategy=shard_strategy,
                             shard_param=True):
            zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(zero_model, shard_strategy)

        model = model_builder(checkpoint=True).half()
        col_model_deepcopy(zero_model, model)
        model = model.cuda()

        zero_state_dict = zero_model.state_dict()
        for key, val in model.state_dict().items():
            assert torch.equal(val, zero_state_dict[key].to(val.device))
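            # this bit-exact comparison only works because state_dict() on the
            # sharded model returns gathered (unsharded) tensors, and `model`
            # is an fp16 replica of the zero model's weights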


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_zero_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_if_address_is_in_use()
def test_zero_state_dict(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_state_dict(2)
@ -1,94 +0,0 @@
import pytest
import torch

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
from colossalai.zero.legacy.gemini.tensor_utils import (
    colo_model_data_move_to_cpu,
    colo_model_data_tensor_move,
    colo_model_data_tensor_move_inline,
    colo_model_tensor_clone,
    colo_tensor_mem_usage,
)


def _run_colo_tensor_mem_usage():
    # range(2) so that both the StatefulTensor branch (i == 1) and the
    # plain tensor branch actually run
    for i in range(2):
        if i == 1:
            t1 = StatefulTensor(torch.randn(2, 2))
            t2 = StatefulTensor(torch.randn(4, 4))
            c1, g1 = colo_tensor_mem_usage(t1)
            c2, g2 = colo_tensor_mem_usage(t2)
            assert c1 * 4 == c2
            assert g1 * 4 == g2
        else:
            t1 = torch.randn(2, 2)
            t2 = torch.randn(4, 4)
            c1, g1 = colo_tensor_mem_usage(t1)
            c2, g2 = colo_tensor_mem_usage(t2)
            assert c1 * 4 == c2
            assert g1 * 4 == g2
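    # the 4x ratios hold because a 4x4 tensor has 16 elements versus the
    # 4 elements of a 2x2 tensor of the same dtype, i.e. exactly 4x the bytes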


def _run_colo_model_data_tensor_move_inline():
    for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]:
        colo_model_data_tensor_move_inline(t, get_current_device())
        assert t.device == get_current_device()


def _run_colo_model_data_tensor_move():
    for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))),
              (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]:
        cpu_t, cuda_t = t
        colo_model_data_tensor_move(cpu_t, cuda_t)
        assert cuda_t.device == get_current_device()


def _run_colo_model_data_move_to_cpu():
    for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]:
        colo_model_data_move_to_cpu(t)
        assert t.device == torch.device("cpu")


def _run_colo_model_tensor_clone():
    for t in [
            StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())),
            torch.randn(4, 4).cuda(torch.cuda.current_device())
    ]:
        if isinstance(t, StatefulTensor):
            assert t.payload.device == get_current_device()
        else:
            assert t.device == get_current_device()
        p = colo_model_tensor_clone(t, get_current_device())
        assert p.device == get_current_device()
        for i in range(2):
            for j in range(2):
                if isinstance(t, StatefulTensor):
                    assert t.payload.device == p.device
                    assert t.payload[i][j] == p[i][j]
                else:
                    assert t.device == p.device
                    assert t[i][j] == p[i][j]


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    _run_colo_tensor_mem_usage()
    _run_colo_model_data_tensor_move_inline()
    _run_colo_model_data_tensor_move()
    _run_colo_model_data_move_to_cpu()
    _run_colo_model_tensor_clone()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_zero_tensor_utils(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_tensor_utils(world_size=2)
@ -1,113 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
import torch.distributed as dist
from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
from colossalai.zero.low_level._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs


def run_dist(rank, world_size, port, parallel_config, bf16):
    is_mp_config = parallel_config == MP_PARALLEL_CONFIG
    is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG
    if bf16:
        parallel_config['zero']['model_config']['bf16'] = True
    colossalai.launch(config=parallel_config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
        with ZeroInitContext(target_device=torch.cuda.current_device(),
                             shard_strategy=gpc.config.zero.model_config.shard_strategy,
                             shard_param=True,
                             bf16=bf16):
            colo_model = model_builder(checkpoint=True)

        colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
        engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                               optimizer=colo_optimizer,
                                                               criterion=criterion,
                                                               train_dataloader=train_dataloader)
        dtype = torch.bfloat16 if bf16 else torch.float16
        torch_model = model_builder(checkpoint=True).to(dtype)
        col_model_deepcopy(engine.model, torch_model)
        torch_model = torch_model.cuda().float()
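        # the reference model is first materialized in the zero model's low
        # precision so col_model_deepcopy can copy the sharded fp16/bf16
        # payloads, then cast back to fp32 to serve as the full-precision baseline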

        engine.train()
        torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)

        if dist.get_world_size() > 1:
            torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()])

        i = 0
        for data, label in train_dataloader:
            if i > 4:
                break

            data, label = data.cuda(), label.cuda()

            engine.zero_grad()
            torch_optimizer.zero_grad()

            if criterion:
                output = engine(data)
                loss = engine.criterion(output, label)

                torch_output = torch_model(data)
                torch_loss = engine.criterion(torch_output, label)
            else:
                loss = engine(data, label)
                torch_loss = torch_model(data, label)

            engine.backward(loss)
            engine.step()

            torch_loss.backward()

            for param in torch_model.parameters():
                if param.grad is not None:
                    assert not has_inf_or_nan(param.grad)

            torch_optimizer.step()
            i += 1

        if is_mp_config:
            check_params(torch_model, colo_model, loose=True)
        elif is_zero_config:
            check_sharded_model_params(torch_model, colo_model, loose=True)


# FIXME: enable this test in next PR
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_mp_engine(world_size):
    spawn(run_dist, world_size, parallel_config=MP_PARALLEL_CONFIG, bf16=False)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("bf16", [True, False])
@rerun_if_address_is_in_use()
def test_zero_engine(world_size, bf16):
    spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16)


if __name__ == '__main__':
    test_zero_engine(world_size=4, bf16=False)