[zero] add load_state_dict for sharded model (#894)

* add load_state_dict for sharded model

* fix bug

* fix bug

* fix ckpt dtype and device

* support load state dict in zero init ctx

* fix bugs
Branch: pull/1040/head
Author: ver217
Date: 2022-05-27 10:25:08 +08:00 (committed by GitHub)
Parent: 6c5996a56e
Commit: 7cfd6c827e
2 changed files with 166 additions and 10 deletions
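
Taken together, the two changed files give `ShardedModelV2` a working `state_dict()` / `load_state_dict()` pair and let `ZeroInitContext` load checkpoints while parameters are sharded. The sketch below shows roughly how the feature is meant to be used; the constructor arguments and import paths are assumptions based on ColossalAI APIs of this era, and a real run additionally needs a launched distributed environment (e.g. via `colossalai.launch_from_torch`).

```python
# Illustrative sketch only: exact signatures vary across ColossalAI versions,
# and this assumes torch.distributed / colossalai have already been launched.
import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

shard_strategy = TensorShardStrategy()

# Parameters are sharded as soon as they are created inside the context.
with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Linear(16, 16)    # stand-in for a real model

sharded_model = ShardedModelV2(model, shard_strategy)

# state_dict() now gathers the shards and returns a full, unsharded checkpoint.
torch.save(sharded_model.state_dict(), 'ckpt.pt')

# load_state_dict() copies a full checkpoint back into the sharded payloads
# and re-shards each parameter afterwards.
sharded_model.load_state_dict(torch.load('ckpt.pt'))
```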

@@ -1,7 +1,6 @@
import contextlib
import functools
from typing import Optional
import torch
import torch.nn as nn
import torch.distributed as dist
@@ -11,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2
from contextlib import AbstractContextManager
from colossalai.utils import InsertPostInitMethodToModuleSubClasses
@@ -128,6 +128,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
        nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout

        self.module_load_from_state_dict = nn.Module._load_from_state_dict
        shard_strategy = self.shard_strategy if self.config.shard_param else None
        nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict,
                                                                   shard_strategy=shard_strategy)
        self.module_state_dict = nn.Module.state_dict
        nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict,
                                                        shard_strategy=shard_strategy,
                                                        state_dict_func=self.module_state_dict,
                                                        process_group=self.dp_process_group)

        # reserve rng states
        self.cpu_rng_state = torch.get_rng_state()
        self.cuda_rng_state = torch.cuda.get_rng_state()
@@ -152,6 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        del self.param_list

        nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
        nn.Module._load_from_state_dict = self.module_load_from_state_dict
        nn.Module.state_dict = self.module_state_dict

        torch.set_rng_state(self.cpu_rng_state)
        torch.cuda.set_rng_state(self.cuda_rng_state)
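
The init-context half of the change works by monkey-patching class attributes of `nn.Module`: `_load_from_state_dict` and `state_dict` are swapped for `functools.partialmethod` wrappers that pre-bind the shard strategy and process group, and the saved originals are put back when the context exits. A minimal, ColossalAI-free sketch of that technique (class and method names here are made up for illustration):

```python
# Self-contained toy showing the patching technique used above: binding extra
# keyword arguments to a replacement method with functools.partialmethod and
# temporarily installing it on a class, then restoring the original afterwards.
import functools


class Greeter:

    def greet(self) -> str:
        return "hello"


def _patched_greet(self, suffix: str = "") -> str:
    # 'self' is still the instance the method is called on.
    return "hello from the patch" + suffix


original_greet = Greeter.greet
# partialmethod is a descriptor, so assigning it to the class keeps normal
# bound-method semantics while pre-filling the extra keyword argument --
# the same trick the context uses for shard_strategy / process_group.
Greeter.greet = functools.partialmethod(_patched_greet, suffix="!")

assert Greeter().greet() == "hello from the patch!"

Greeter.greet = original_greet    # restore on exit, as the context does
assert Greeter().greet() == "hello"
```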

@@ -1,7 +1,9 @@
import functools
from collections import OrderedDict
from typing import Any, Optional
from typing import Any, Optional, Iterator, Tuple
from copy import deepcopy
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
import itertools
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -377,18 +379,160 @@ class ShardedModelV2(nn.Module):
        # keep saved_grad in HOLD state
        param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        return self.module.parameters(recurse=recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        return self.module.named_parameters(prefix, recurse)

    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
        for p in self.sharded_params:
        return self._colo_state_dict(destination,
                                     prefix,
                                     keep_vars,
                                     shard_strategy=self.shard_strategy,
                                     state_dict_func=nn.Module.state_dict,
                                     module_to_load=self.module,
                                     sharded_params=self.sharded_params,
                                     process_group=self.process_group)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
        for name, p in self.named_parameters():
            if name in state_dict:
                p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype,
                                                                   device=p.colo_attr.data_payload.device))
                # Force re-shard
                p.colo_attr.sharded_data_tensor.is_sharded = False
                self.shard_strategy.shard([p.colo_attr.sharded_data_tensor])
            elif strict:
                raise RuntimeError(f'Missing key in state_dict: {name}')

    def _colo_state_dict(self,
                         destination=None,
                         prefix='',
                         keep_vars=False,
                         shard_strategy: Optional[BaseShardStrategy] = None,
                         state_dict_func=None,
                         module_to_load=None,
                         sharded_params=[],
                         process_group=None) -> 'OrderedDict[str, torch.Tensor]':
        if len(sharded_params) == 0:
            for param in self.parameters():
                if param.colo_attr.param_is_sharded:
                    sharded_params.append(param)
        if shard_strategy is not None:
            shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
        for p in sharded_params:
            p.data = p.colo_attr.data_payload
        gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
        for p in self.sharded_params:
        module_to_load = module_to_load or self
        gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars))
        if shard_strategy is not None:
            shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
        for p in sharded_params:
            p.colo_attr.set_data_none()
        return gathered_state_dict

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        raise NotImplementedError

    def _colo_load_from_state_dict(self,
                                   state_dict,
                                   prefix,
                                   local_metadata,
                                   strict,
                                   missing_keys,
                                   unexpected_keys,
                                   error_msgs,
                                   shard_strategy=None):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
                See
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                if hasattr(param, 'colo_attr'):
                    param.colo_attr.data_payload_reset(
                        input_param.to(dtype=param.colo_attr.data_payload.dtype,
                                       device=param.colo_attr.data_payload.device))
                    if shard_strategy is not None:
                        # Force re-shard
                        param.colo_attr.sharded_data_tensor.is_sharded = False
                        shard_strategy.shard([param.colo_attr.sharded_data_tensor])
                else:
                    # This is used to avoid copying uninitialized parameters into
                    # non-lazy modules, since they don't have the hook to do the checks
                    # in such case, it will error when accessing the .shape attribute.
                    is_param_lazy = torch.nn.parameter.is_lazy(param)
                    # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                    if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                        input_param = input_param[0]
                    if not is_param_lazy and input_param.shape != param.shape:
                        # local shape should match the one in checkpoint
                        error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                          'the shape in current model is {}.'.format(
                                              key, input_param.shape, param.shape))
                        continue
                    try:
                        with torch.no_grad():
                            param.copy_(input_param)
                    except Exception as ex:
                        error_msgs.append('While copying the parameter named "{}", '
                                          'whose dimensions in the model are {} and '
                                          'whose dimensions in the checkpoint are {}, '
                                          'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
                                                                               ex.args))
            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix):]
                    input_name = input_name.split('.', 1)[0]    # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)

    def __getitem__(self, idx: int):
        assert isinstance(self.module, nn.ModuleList)
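
For context on why the patch targets `nn.Module._load_from_state_dict` rather than `load_state_dict` itself: PyTorch's `load_state_dict` walks the module tree and calls `_load_from_state_dict` once per submodule with that submodule's key prefix, so overriding the per-module hook (as `_colo_load_from_state_dict` does) changes how every parameter is materialized without reimplementing the top-level traversal. The snippet below is plain PyTorch and independent of this patch:

```python
# nn.Module.load_state_dict recurses over submodules and invokes each one's
# _load_from_state_dict with the matching key prefix.
import torch
import torch.nn as nn

calls = []


class Traced(nn.Linear):

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        calls.append(prefix)    # record which submodule prefix is being loaded
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


model = nn.Sequential(Traced(4, 4), Traced(4, 2))
model.load_state_dict(model.state_dict())
print(calls)    # ['0.', '1.'] -- one hook call per traced submodule, prefix included
```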