mirror of https://github.com/hpcaitech/ColossalAI
[zero] add load_state_dict for sharded model (#894)
* add load_state_dict for sharded model
* fix bug
* fix bug
* fix ckpt dtype and device
* support load state dict in zero init ctx
* fix bugs

pull/1040/head
parent 6c5996a56e
commit 7cfd6c827e
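
In short, the commit lets a sharded model be checkpointed and restored through the usual state_dict()/load_state_dict() calls. A rough usage sketch, assuming a ColossalAI distributed environment has already been launched; the constructor arguments shown for ZeroInitContext and ShardedModelV2 are indicative of this era of the API rather than exact, and the model and checkpoint path are placeholders:

import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2

shard_strategy = TensorShardStrategy()

# Parameters created inside the context are wrapped in ShardedParamV2 and sharded on the fly.
with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    module = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

model = ShardedModelV2(module, shard_strategy)

# state_dict() gathers the shards, deep-copies the full tensors, then re-shards.
torch.save(model.state_dict(), 'ckpt.pt')

# load_state_dict() casts each checkpoint tensor to the sharded payload's
# dtype/device, overwrites the payload, and forces a re-shard.
model.load_state_dict(torch.load('ckpt.pt'), strict=True)
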
@@ -1,7 +1,6 @@
import contextlib
import functools
from typing import Optional

import torch
import torch.nn as nn
import torch.distributed as dist
@@ -11,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2
from contextlib import AbstractContextManager
from colossalai.utils import InsertPostInitMethodToModuleSubClasses
@@ -128,6 +128,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
        nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout

        self.module_load_from_state_dict = nn.Module._load_from_state_dict
        shard_strategy = self.shard_strategy if self.config.shard_param else None
        nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict,
                                                                   shard_strategy=shard_strategy)
        self.module_state_dict = nn.Module.state_dict
        nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict,
                                                       shard_strategy=shard_strategy,
                                                       state_dict_func=self.module_state_dict,
                                                       process_group=self.dp_process_group)

        # reserve rng states
        self.cpu_rng_state = torch.get_rng_state()
        self.cuda_rng_state = torch.cuda.get_rng_state()
@@ -152,6 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        del self.param_list

        nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
        nn.Module._load_from_state_dict = self.module_load_from_state_dict
        nn.Module.state_dict = self.module_state_dict
        torch.set_rng_state(self.cpu_rng_state)
        torch.cuda.set_rng_state(self.cuda_rng_state)
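
The pattern here is a plain save/patch/restore of class attributes: the original nn.Module.state_dict and nn.Module._load_from_state_dict are stashed on the context, sharded-aware versions are installed for the lifetime of the context, and the originals are put back on exit. functools.partialmethod is used so the extra keyword arguments (shard_strategy, process_group, the original state_dict function) are pre-bound while self is still supplied by whichever module the method is called on. A minimal, self-contained sketch of that pattern with a toy class (the names are invented for illustration, not ColossalAI code):

import functools

class Greeter:
    def hello(self):
        return 'hello'

def _patched_hello(self, suffix=''):
    # Stand-in for ShardedModelV2._colo_state_dict: an unbound function that
    # is bound to whatever instance it is eventually called on.
    return 'hello' + suffix

class PatchHello:
    """Swap Greeter.hello for the patched version inside a with-block."""

    def __enter__(self):
        self._orig_hello = Greeter.hello
        # partialmethod pre-binds keyword arguments but still lets the
        # descriptor protocol supply `self`, exactly like the patch above.
        Greeter.hello = functools.partialmethod(_patched_hello, suffix='!')
        return self

    def __exit__(self, *exc):
        Greeter.hello = self._orig_hello

g = Greeter()
with PatchHello():
    assert g.hello() == 'hello!'   # patched while the context is active
assert g.hello() == 'hello'        # original restored on exit
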
@@ -1,7 +1,9 @@
import functools
from collections import OrderedDict
from typing import Any, Optional, Iterator, Tuple
from copy import deepcopy
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
import itertools
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -377,18 +379,160 @@ class ShardedModelV2(nn.Module):
            # keep saved_grad in HOLD state
            param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        return self.module.parameters(recurse=recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        return self.module.named_parameters(prefix, recurse)

    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
        return self._colo_state_dict(destination,
                                     prefix,
                                     keep_vars,
                                     shard_strategy=self.shard_strategy,
                                     state_dict_func=nn.Module.state_dict,
                                     module_to_load=self.module,
                                     sharded_params=self.sharded_params,
                                     process_group=self.process_group)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
        for name, p in self.named_parameters():
            if name in state_dict:
                p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype,
                                                                   device=p.colo_attr.data_payload.device))
                # Force re-shard
                p.colo_attr.sharded_data_tensor.is_sharded = False
                self.shard_strategy.shard([p.colo_attr.sharded_data_tensor])
            elif strict:
                raise RuntimeError(f'Missing key in state_dict: {name}')

    def _colo_state_dict(self,
                         destination=None,
                         prefix='',
                         keep_vars=False,
                         shard_strategy: Optional[BaseShardStrategy] = None,
                         state_dict_func=None,
                         module_to_load=None,
                         sharded_params=[],
                         process_group=None) -> 'OrderedDict[str, torch.Tensor]':
        if len(sharded_params) == 0:
            for param in self.parameters():
                if param.colo_attr.param_is_sharded:
                    sharded_params.append(param)
        if shard_strategy is not None:
            shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
        for p in sharded_params:
            p.data = p.colo_attr.data_payload
        module_to_load = module_to_load or self
        gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars))
        if shard_strategy is not None:
            shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
        for p in sharded_params:
            p.colo_attr.set_data_none()
        return gathered_state_dict
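
The method above follows a gather -> copy -> re-shard round trip: rebuild the full tensors from their shards, deep-copy the gathered state dict so the returned tensors stay valid after the payloads are released, then shard again and drop the temporary .data references. A toy, single-process sketch of that round trip with plain tensors (gather_shards and shard_tensor are invented stand-ins for the shard strategy, not the BaseShardStrategy API):

import copy
import torch

def gather_shards(shards):
    # Toy stand-in for shard_strategy.gather(): concatenate every rank's slice.
    return torch.cat(shards)

def shard_tensor(full, rank, world_size):
    # Toy stand-in for shard_strategy.shard(): keep only this rank's slice.
    return full.chunk(world_size)[rank]

rank, world_size = 0, 2
local_shard = torch.arange(4.)       # this rank's slice of an 8-element parameter
peer_shard = torch.arange(4., 8.)    # what the other rank would contribute

# 1. gather: rebuild the full parameter from all shards
full = gather_shards([local_shard, peer_shard])
# 2. copy: snapshot it so the checkpoint survives the re-shard
state = {'weight': copy.deepcopy(full)}
# 3. re-shard: go back to holding only the local slice
local_shard = shard_tensor(full, rank, world_size)

assert state['weight'].numel() == 8 and local_shard.numel() == 4
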
    def _colo_load_from_state_dict(self,
                                   state_dict,
                                   prefix,
                                   local_metadata,
                                   strict,
                                   missing_keys,
                                   unexpected_keys,
                                   error_msgs,
                                   shard_strategy=None):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
                See
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                if hasattr(param, 'colo_attr'):
                    param.colo_attr.data_payload_reset(
                        input_param.to(dtype=param.colo_attr.data_payload.dtype,
                                       device=param.colo_attr.data_payload.device))
                    if shard_strategy is not None:
                        # Force re-shard
                        param.colo_attr.sharded_data_tensor.is_sharded = False
                        shard_strategy.shard([param.colo_attr.sharded_data_tensor])
                else:
                    # This is used to avoid copying uninitialized parameters into
                    # non-lazy modules, since they don't have the hook to do the checks
                    # in such case, it will error when accessing the .shape attribute.
                    is_param_lazy = torch.nn.parameter.is_lazy(param)
                    # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                    if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                        input_param = input_param[0]

                    if not is_param_lazy and input_param.shape != param.shape:
                        # local shape should match the one in checkpoint
                        error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                          'the shape in current model is {}.'.format(
                                              key, input_param.shape, param.shape))
                        continue
                    try:
                        with torch.no_grad():
                            param.copy_(input_param)
                    except Exception as ex:
                        error_msgs.append('While copying the parameter named "{}", '
                                          'whose dimensions in the model are {} and '
                                          'whose dimensions in the checkpoint are {}, '
                                          'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
                                                                               ex.args))
            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix):]
                    input_name = input_name.split('.', 1)[0]    # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)
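
The missing_keys/unexpected_keys/error_msgs bookkeeping above mirrors stock torch.nn.Module._load_from_state_dict, so the usual strict and non-strict semantics still apply once the patched hook is installed. As a reminder of that standard behavior, with a plain PyTorch module and no ColossalAI involved:

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
ckpt = m.state_dict()
ckpt['extra'] = torch.zeros(1)       # an unexpected key
del ckpt['bias']                     # a missing key

# strict=False returns the bookkeeping instead of raising
result = m.load_state_dict(ckpt, strict=False)
print(result.missing_keys)           # ['bias']
print(result.unexpected_keys)        # ['extra']

# strict=True raises a RuntimeError listing the same keys
try:
    m.load_state_dict(ckpt, strict=True)
except RuntimeError as e:
    print('strict load failed:', e)
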

    def __getitem__(self, idx: int):
        assert isinstance(self.module, nn.ModuleList)