mirror of https://github.com/hpcaitech/ColossalAI
[ddp] add save/load state dict for ColoDDP (#1127)
* add save/load state dict for ColoDDP * add unit test * refactor unit test folder * polish unit test * rename unit testpull/1138/head
parent
946dbd629d
commit
d26902645e
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import itertools
|
||||
import torch.distributed as dist
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
|
@ -7,8 +8,14 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
|||
from colossalai.tensor.chunk import TensorState, Chunk
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from typing import Dict, Iterable
|
||||
from typing import Dict, Iterable, List
|
||||
from colossalai.logging import get_dist_logger
|
||||
from collections import OrderedDict
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
|
||||
|
||||
def free_storage(data: torch.Tensor) -> None:
|
||||
|
@ -122,6 +129,12 @@ class ColoDDP(torch.nn.Module):
|
|||
for p in params_to_ignore:
|
||||
p._ddp_to_ignore = True
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
|
||||
return self.module.load_state_dict(state_dict, strict)
|
||||
|
||||
|
||||
class ColoDDPV2(ColoDDP):
|
||||
|
||||
|
@ -130,7 +143,7 @@ class ColoDDPV2(ColoDDP):
|
|||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager = gemini_manager.chunk_manager
|
||||
self.param_op_hook = ZeROHookV2(gemini_manager)
|
||||
self.fp32_params = []
|
||||
self.fp32_params: List[ColoParameter] = []
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||
self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
|
||||
|
@ -205,3 +218,208 @@ class ColoDDPV2(ColoDDP):
|
|||
def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
||||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
r"""Returns a dictionary containing a whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are
|
||||
included. Keys are corresponding parameter and buffer names.
|
||||
Parameters and buffers set to ``None`` are not included.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
a dictionary containing a whole state of the module
|
||||
|
||||
Example::
|
||||
|
||||
>>> module.state_dict().keys()
|
||||
['bias', 'weight']
|
||||
|
||||
"""
|
||||
if destination is None:
|
||||
destination = OrderedDict()
|
||||
destination._metadata = OrderedDict()
|
||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars)
|
||||
for hook in self._state_dict_hooks.values():
|
||||
hook_result = hook(self, destination, prefix, local_metadata)
|
||||
if hook_result is not None:
|
||||
destination = hook_result
|
||||
return destination
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
|
||||
In rare cases, subclasses can achieve class-specific behavior by
|
||||
overriding this method with custom logic.
|
||||
|
||||
Args:
|
||||
destination (dict): a dict where state will be stored
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
chunks = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunks:
|
||||
self.chunk_manager.access_chunk(chunk)
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
destination[prefix + name] = fp32_p.clone() if keep_vars else fp32_p.clone().detach()
|
||||
for chunk in chunks:
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
for name, buf in self.named_buffers():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
destination[extra_state_key] = self.get_extra_state()
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants. If :attr:`strict` is ``True``, then
|
||||
the keys of :attr:`state_dict` must exactly match the keys returned
|
||||
by this module's :meth:`~torch.nn.Module.state_dict` function.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
strict (bool, optional): whether to strictly enforce that the keys
|
||||
in :attr:`state_dict` match the keys returned by this module's
|
||||
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
|
||||
|
||||
Returns:
|
||||
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
|
||||
* **missing_keys** is a list of str containing the missing keys
|
||||
* **unexpected_keys** is a list of str containing the unexpected keys
|
||||
|
||||
Note:
|
||||
If a parameter or buffer is registered as ``None`` and its corresponding key
|
||||
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
|
||||
``RuntimeError``.
|
||||
"""
|
||||
missing_keys: List[str] = []
|
||||
unexpected_keys: List[str] = []
|
||||
error_msgs: List[str] = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
# mypy isn't aware that "_metadata" exists in state_dict
|
||||
state_dict._metadata = metadata # type: ignore[attr-defined]
|
||||
|
||||
prefix = ''
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
|
||||
'"{}"'.format(k) for k in unexpected_keys)))
|
||||
if len(missing_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
||||
this module, but not its descendants. This is called on every submodule
|
||||
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
|
||||
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
||||
For state dicts without metadata, :attr:`local_metadata` is empty.
|
||||
Subclasses can achieve class-specific backward compatible loading using
|
||||
the version number at `local_metadata.get("version", None)`.
|
||||
|
||||
.. note::
|
||||
:attr:`state_dict` is not the same object as the input
|
||||
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
|
||||
it can be modified.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module.
|
||||
See
|
||||
strict (bool): whether to strictly enforce that the keys in
|
||||
:attr:`state_dict` with :attr:`prefix` match the names of
|
||||
parameters and buffers in this module
|
||||
missing_keys (list of str): if ``strict=True``, add missing keys to
|
||||
this list
|
||||
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
||||
keys to this list
|
||||
error_msgs (list of str): error messages should be added to this
|
||||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
"""
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
def load(name, dest_tensor, copy_func):
|
||||
key = prefix + name
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
if input_param.shape != dest_tensor.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||
'the shape in current model is {}.'.format(key, input_param.shape,
|
||||
dest_tensor.shape))
|
||||
return
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, input_param)
|
||||
copy_func(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}, '
|
||||
'an exception occurred : {}.'.format(key, dest_tensor.size(), input_param.size(),
|
||||
ex.args))
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
|
||||
def load_fp32_p(fp32_p, data):
|
||||
if fp32_p.storage().size() > 0:
|
||||
self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, data)
|
||||
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
load(name, fp32_p, partial(load_fp32_p, fp32_p))
|
||||
self.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
load(name, buf, buf.copy_)
|
||||
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "set_extra_state",
|
||||
torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
|
||||
if extra_state_key in state_dict:
|
||||
self.set_extra_state(state_dict[extra_state_key])
|
||||
elif strict:
|
||||
missing_keys.append(extra_state_key)
|
||||
elif strict and (extra_state_key in state_dict):
|
||||
unexpected_keys.append(extra_state_key)
|
||||
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix) and key != extra_state_key:
|
||||
input_name = key[len(prefix):]
|
||||
if input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import ChunkManager
|
||||
from functools import partial
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from colossalai.nn.parallel import ColoDDPV2, ColoDDP
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from typing import Callable
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
|
||||
for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
|
||||
assert k1 == k2
|
||||
assert torch.allclose(t1, t2, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
||||
return ColoDDP(module)
|
||||
|
||||
|
||||
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ColoDDPV2:
|
||||
chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
|
||||
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
||||
return ColoDDPV2(module, gemini_manager)
|
||||
|
||||
|
||||
def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
|
||||
get_components_func = non_distributed_component_funcs.get_callable('nested_model')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
torch_model = model_builder().cuda()
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = ddp_init_func(model)
|
||||
torch_state_dict = torch_model.state_dict()
|
||||
model.load_state_dict(torch_state_dict)
|
||||
state_dict = model.state_dict()
|
||||
check_state_dict_equal(torch_state_dict, state_dict)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_state_dict(init_ddp)
|
||||
run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=False))
|
||||
run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=True))
|
||||
run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=False))
|
||||
run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=True))
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_state_dict(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_state_dict(2)
|
Loading…
Reference in New Issue