[zero] add load_state_dict for sharded model (#894)

* add load_state_dict for sharded model

* fix bug

* fix bug

* fix ckpt dtype and device

* support load state dict in zero init ctx

* fix bugs
ver217 committed 2022-05-27 10:25:08 +08:00 (via GitHub)
parent 6c5996a56e
commit 7cfd6c827e
2 changed files with 166 additions and 10 deletions
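
In rough terms, the diff below lets a ZeRO-sharded model be checkpointed and restored through the usual state_dict()/load_state_dict() calls. A minimal usage sketch follows; it is not part of the commit, and the ZeroInitContext/ShardedModelV2 constructor arguments, the TensorShardStrategy choice, and the toy nn.Linear model are illustrative assumptions based on ColossalAI 0.1.x APIs.

# Usage sketch only (not part of the diff): constructor arguments and import
# paths are assumptions and may differ between ColossalAI releases.
import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext            # assumed path
from colossalai.zero.shard_utils import TensorShardStrategy     # assumed strategy
from colossalai.zero.sharded_model import ShardedModelV2        # assumed path

shard_strategy = TensorShardStrategy()

# Parameters built inside the context are sharded as they are initialized.
with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Linear(16, 16)

sharded_model = ShardedModelV2(model, shard_strategy)

# state_dict() temporarily gathers the shards, so the checkpoint looks like an
# ordinary full state dict and can be saved as usual.
torch.save(sharded_model.state_dict(), 'ckpt.pt')

# The new load_state_dict() copies each checkpoint tensor into the matching
# sharded payload (casting dtype/device) and re-shards it in place.
sharded_model.load_state_dict(torch.load('ckpt.pt'))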

Changed file 1 of 2: ZeroInitContext

@@ -1,7 +1,6 @@
 import contextlib
 import functools
 from typing import Optional
 import torch
 import torch.nn as nn
 import torch.distributed as dist
@@ -11,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_param import ShardedParamV2
 from contextlib import AbstractContextManager
 from colossalai.utils import InsertPostInitMethodToModuleSubClasses
@@ -128,6 +128,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
         nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout

+        self.module_load_from_state_dict = nn.Module._load_from_state_dict
+        shard_strategy = self.shard_strategy if self.config.shard_param else None
+        nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict,
+                                                                  shard_strategy=shard_strategy)
+        self.module_state_dict = nn.Module.state_dict
+        nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict,
+                                                       shard_strategy=shard_strategy,
+                                                       state_dict_func=self.module_state_dict,
+                                                       process_group=self.dp_process_group)
+
         # reserve rng states
         self.cpu_rng_state = torch.get_rng_state()
         self.cuda_rng_state = torch.cuda.get_rng_state()
@@ -152,6 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         del self.param_list

         nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
+        nn.Module.load_state_dict = self.module_load_from_state_dict
+        nn.Module.state_dict = self.module_state_dict

         torch.set_rng_state(self.cpu_rng_state)
         torch.cuda.set_rng_state(self.cuda_rng_state)
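
The hunks above work by temporarily monkey-patching nn.Module at class level: functools.partialmethod pre-binds the extra keyword arguments (the shard strategy, the saved original state_dict, the process group) onto the ShardedModelV2 helpers, and the saved attributes are assigned back when the context exits. A minimal, self-contained sketch of that patch/restore pattern, with placeholder names and no ColossalAI dependencies, is:

# Sketch only: shows the class-level patch/restore idiom used by ZeroInitContext.
# _tagged_state_dict and PatchStateDict are invented names for illustration.
import functools
import torch.nn as nn


def _tagged_state_dict(self, *args, tag='', state_dict_func=None, **kwargs):
    # Call the saved original implementation, then post-process its result.
    out = state_dict_func(self, *args, **kwargs)
    print(f'[{tag}] collected {len(out)} entries')
    return out


class PatchStateDict:

    def __enter__(self):
        self._orig_state_dict = nn.Module.state_dict
        # partialmethod pre-binds keyword arguments; instance access still
        # passes the module as `self`, just like a normal method.
        nn.Module.state_dict = functools.partialmethod(_tagged_state_dict,
                                                       tag='patched',
                                                       state_dict_func=self._orig_state_dict)
        return self

    def __exit__(self, *exc):
        # Restore the exact attribute that was patched.
        nn.Module.state_dict = self._orig_state_dict


with PatchStateDict():
    print(list(nn.Linear(2, 2).state_dict().keys()))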

Changed file 2 of 2: ShardedModelV2

@@ -1,7 +1,9 @@
 import functools
 from collections import OrderedDict
-from typing import Any, Optional
+from typing import Any, Optional, Iterator, Tuple
+from copy import deepcopy
+from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+import itertools
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -377,18 +379,160 @@
         # keep saved_grad in HOLD state
         param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        return self.module.parameters(recurse=recurse)
+
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
+        return self.module.named_parameters(prefix, recurse)
+
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
-        for p in self.sharded_params:
+        return self._colo_state_dict(destination,
+                                     prefix,
+                                     keep_vars,
+                                     shard_strategy=self.shard_strategy,
+                                     state_dict_func=nn.Module.state_dict,
+                                     module_to_load=self.module,
+                                     sharded_params=self.sharded_params,
+                                     process_group=self.process_group)
+
+    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
+        for name, p in self.named_parameters():
+            if name in state_dict:
+                p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype,
+                                                                   device=p.colo_attr.data_payload.device))
+                # Force re-shard
+                p.colo_attr.sharded_data_tensor.is_sharded = False
+                self.shard_strategy.shard([p.colo_attr.sharded_data_tensor])
+            elif strict:
+                raise RuntimeError(f'Missing key in state_dict: {name}')
+
+    def _colo_state_dict(self,
+                         destination=None,
+                         prefix='',
+                         keep_vars=False,
+                         shard_strategy: Optional[BaseShardStrategy] = None,
+                         state_dict_func=None,
+                         module_to_load=None,
+                         sharded_params=[],
+                         process_group=None) -> 'OrderedDict[str, torch.Tensor]':
+        if len(sharded_params) == 0:
+            for param in self.parameters():
+                if param.colo_attr.param_is_sharded:
+                    sharded_params.append(param)
+        if shard_strategy is not None:
+            shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
+        for p in sharded_params:
             p.data = p.colo_attr.data_payload
-        gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
-        for p in self.sharded_params:
+        module_to_load = module_to_load or self
+        gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars))
+        if shard_strategy is not None:
+            shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
+        for p in sharded_params:
             p.colo_attr.set_data_none()
         return gathered_state_dict

-    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
-        raise NotImplementedError
+    def _colo_load_from_state_dict(self,
+                                   state_dict,
+                                   prefix,
+                                   local_metadata,
+                                   strict,
+                                   missing_keys,
+                                   unexpected_keys,
+                                   error_msgs,
+                                   shard_strategy=None):
+        r"""Copies parameters and buffers from :attr:`state_dict` into only
+        this module, but not its descendants. This is called on every submodule
+        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+        For state dicts without metadata, :attr:`local_metadata` is empty.
+        Subclasses can achieve class-specific backward compatible loading using
+        the version number at `local_metadata.get("version", None)`.
+
+        .. note::
+            :attr:`state_dict` is not the same object as the input
+            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+            it can be modified.
+
+        Args:
+            state_dict (dict): a dict containing parameters and
+                persistent buffers.
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+            local_metadata (dict): a dict containing the metadata for this module.
+                See
+            strict (bool): whether to strictly enforce that the keys in
+                :attr:`state_dict` with :attr:`prefix` match the names of
+                parameters and buffers in this module
+            missing_keys (list of str): if ``strict=True``, add missing keys to
+                this list
+            unexpected_keys (list of str): if ``strict=True``, add unexpected
+                keys to this list
+            error_msgs (list of str): error messages should be added to this
+                list, and will be reported together in
+                :meth:`~torch.nn.Module.load_state_dict`
+        """
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        for name, param in local_state.items():
+            key = prefix + name
+            if key in state_dict:
+                input_param = state_dict[key]
+                if hasattr(param, 'colo_attr'):
+                    param.colo_attr.data_payload_reset(
+                        input_param.to(dtype=param.colo_attr.data_payload.dtype,
+                                       device=param.colo_attr.data_payload.device))
+                    if shard_strategy is not None:
+                        # Force re-shard
+                        param.colo_attr.sharded_data_tensor.is_sharded = False
+                        shard_strategy.shard([param.colo_attr.sharded_data_tensor])
+                else:
+                    # This is used to avoid copying uninitialized parameters into
+                    # non-lazy modules, since they dont have the hook to do the checks
+                    # in such case, it will error when accessing the .shape attribute.
+                    is_param_lazy = torch.nn.parameter.is_lazy(param)
+                    # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+                    if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                        input_param = input_param[0]
+
+                    if not is_param_lazy and input_param.shape != param.shape:
+                        # local shape should match the one in checkpoint
+                        error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                          'the shape in current model is {}.'.format(
+                                              key, input_param.shape, param.shape))
+                        continue
+                    try:
+                        with torch.no_grad():
+                            param.copy_(input_param)
+                    except Exception as ex:
+                        error_msgs.append('While copying the parameter named "{}", '
+                                          'whose dimensions in the model are {} and '
+                                          'whose dimensions in the checkpoint are {}, '
+                                          'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
+                                                                               ex.args))
+            elif strict:
+                missing_keys.append(key)
+
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state:
+            if extra_state_key in state_dict:
+                self.set_extra_state(state_dict[extra_state_key])
+            elif strict:
+                missing_keys.append(extra_state_key)
+        elif strict and (extra_state_key in state_dict):
+            unexpected_keys.append(extra_state_key)
+
+        if strict:
+            for key in state_dict.keys():
+                if key.startswith(prefix) and key != extra_state_key:
+                    input_name = key[len(prefix):]
+                    input_name = input_name.split('.', 1)[0]    # get the name of param/buffer/child
+                    if input_name not in self._modules and input_name not in local_state:
+                        unexpected_keys.append(key)
+
     def __getitem__(self, idx: int):
         assert isinstance(self.module, nn.ModuleList)
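
For orientation, the gather / snapshot / re-shard flow that _colo_state_dict performs above can be boiled down to the following framework-free sketch. FakeShardedParam and the helper names are invented for illustration; in the real code the gathered payload is released when the tensor is re-sharded, which is why the diff snapshots the result of the original state_dict with deepcopy before sharding again.

# Illustrative sketch only: mirrors the gather -> snapshot -> re-shard pattern of
# _colo_state_dict without any ColossalAI types. All names here are made up.
from collections import OrderedDict
from copy import deepcopy

import torch


class FakeShardedParam:
    """Stand-in for a sharded parameter: only 1/world_size of the data stays resident."""

    def __init__(self, full: torch.Tensor, rank: int = 0, world_size: int = 2):
        self.local_shard = full.flatten().chunk(world_size)[rank].clone()
        self.full_payload = None                 # populated only while gathered

    def gather(self, full: torch.Tensor):
        # The real shard strategy all-gathers shards across ranks; reuse `full` here.
        self.full_payload = full.clone()

    def reshard(self):
        # The real strategy frees / reuses the gathered buffer at this point.
        self.full_payload = None


def sketch_state_dict(params: 'OrderedDict[str, FakeShardedParam]',
                      full_values: 'OrderedDict[str, torch.Tensor]') -> 'OrderedDict[str, torch.Tensor]':
    for name, p in params.items():               # 1. gather: materialize full tensors
        p.gather(full_values[name])
    # 2. snapshot: copy so the returned dict owns independent tensors that survive
    #    step 3 (the role played by deepcopy(state_dict_func(...)) in the diff)
    out = deepcopy(OrderedDict((name, p.full_payload) for name, p in params.items()))
    for p in params.values():                    # 3. re-shard: drop the full tensors
        p.reshard()
    return out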