diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 8c125db29..8e44c0632 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -1,7 +1,6 @@
 import contextlib
 import functools
 from typing import Optional
-
 import torch
 import torch.nn as nn
 import torch.distributed as dist
@@ -11,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_param import ShardedParamV2
 from contextlib import AbstractContextManager
 from colossalai.utils import InsertPostInitMethodToModuleSubClasses
@@ -128,6 +128,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
         nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout
 
+        self.module_load_from_state_dict = nn.Module._load_from_state_dict
+        shard_strategy = self.shard_strategy if self.config.shard_param else None
+        nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict,
+                                                                  shard_strategy=shard_strategy)
+        self.module_state_dict = nn.Module.state_dict
+        nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict,
+                                                       shard_strategy=shard_strategy,
+                                                       state_dict_func=self.module_state_dict,
+                                                       process_group=self.dp_process_group)
+
         # reserve rng states
         self.cpu_rng_state = torch.get_rng_state()
         self.cuda_rng_state = torch.cuda.get_rng_state()
@@ -152,6 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         del self.param_list
 
         nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
+        nn.Module._load_from_state_dict = self.module_load_from_state_dict
+        nn.Module.state_dict = self.module_state_dict
         torch.set_rng_state(self.cpu_rng_state)
         torch.cuda.set_rng_state(self.cuda_rng_state)
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index cc37ddf17..5f087ecab 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -1,7 +1,9 @@
 import functools
 from collections import OrderedDict
-from typing import Any, Optional
-
+from typing import Any, Optional, Iterator, Tuple
+from copy import deepcopy
+from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+import itertools
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -377,18 +379,160 @@ class ShardedModelV2(nn.Module):
         # keep saved_grad in HOLD state
         param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
 
+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        return self.module.parameters(recurse=recurse)
+
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
+        return self.module.named_parameters(prefix, recurse)
+
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
-        for p in self.sharded_params:
+        return self._colo_state_dict(destination,
+                                     prefix,
+                                     keep_vars,
+                                     shard_strategy=self.shard_strategy,
+                                     state_dict_func=nn.Module.state_dict,
+                                     module_to_load=self.module,
+                                     sharded_params=self.sharded_params,
+                                     process_group=self.process_group)
+
+    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
+        for name, p in self.named_parameters():
+            if name in state_dict:
+                p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype,
+                                                                   device=p.colo_attr.data_payload.device))
+                # Force re-shard
+                p.colo_attr.sharded_data_tensor.is_sharded = False
+                self.shard_strategy.shard([p.colo_attr.sharded_data_tensor])
+            elif strict:
+                raise RuntimeError(f'Missing key in state_dict: {name}')
+
+    def _colo_state_dict(self,
+                         destination=None,
+                         prefix='',
+                         keep_vars=False,
+                         shard_strategy: Optional[BaseShardStrategy] = None,
+                         state_dict_func=None,
+                         module_to_load=None,
+                         sharded_params=[],
+                         process_group=None) -> 'OrderedDict[str, torch.Tensor]':
+        if len(sharded_params) == 0:
+            for param in self.parameters():
+                if param.colo_attr.param_is_sharded:
+                    sharded_params.append(param)
+        if shard_strategy is not None:
+            shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
+        for p in sharded_params:
             p.data = p.colo_attr.data_payload
-        gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
-        for p in self.sharded_params:
+        module_to_load = module_to_load or self
+        gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars))
+        if shard_strategy is not None:
+            shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
+        for p in sharded_params:
             p.colo_attr.set_data_none()
         return gathered_state_dict
 
-    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
-        raise NotImplementedError
+    def _colo_load_from_state_dict(self,
+                                   state_dict,
+                                   prefix,
+                                   local_metadata,
+                                   strict,
+                                   missing_keys,
+                                   unexpected_keys,
+                                   error_msgs,
+                                   shard_strategy=None):
+        r"""Copies parameters and buffers from :attr:`state_dict` into only
+        this module, but not its descendants. This is called on every submodule
+        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+        For state dicts without metadata, :attr:`local_metadata` is empty.
+        Subclasses can achieve class-specific backward compatible loading using
+        the version number at `local_metadata.get("version", None)`.
+
+        .. note::
+            :attr:`state_dict` is not the same object as the input
+            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+            it can be modified.
+
+        Args:
+            state_dict (dict): a dict containing parameters and
+                persistent buffers.
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+            local_metadata (dict): a dict containing the metadata for this module.
+                See
+            strict (bool): whether to strictly enforce that the keys in
+                :attr:`state_dict` with :attr:`prefix` match the names of
+                parameters and buffers in this module
+            missing_keys (list of str): if ``strict=True``, add missing keys to
+                this list
+            unexpected_keys (list of str): if ``strict=True``, add unexpected
+                keys to this list
+            error_msgs (list of str): error messages should be added to this
+                list, and will be reported together in
+                :meth:`~torch.nn.Module.load_state_dict`
+        """
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        for name, param in local_state.items():
+            key = prefix + name
+            if key in state_dict:
+                input_param = state_dict[key]
+                if hasattr(param, 'colo_attr'):
+                    param.colo_attr.data_payload_reset(
+                        input_param.to(dtype=param.colo_attr.data_payload.dtype,
+                                       device=param.colo_attr.data_payload.device))
+                    if shard_strategy is not None:
+                        # Force re-shard
+                        param.colo_attr.sharded_data_tensor.is_sharded = False
+                        shard_strategy.shard([param.colo_attr.sharded_data_tensor])
+                else:
+                    # This is used to avoid copying uninitialized parameters into
+                    # non-lazy modules, since they dont have the hook to do the checks
+                    # in such case, it will error when accessing the .shape attribute.
+                    is_param_lazy = torch.nn.parameter.is_lazy(param)
+                    # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+                    if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                        input_param = input_param[0]
+
+                    if not is_param_lazy and input_param.shape != param.shape:
+                        # local shape should match the one in checkpoint
+                        error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                          'the shape in current model is {}.'.format(
+                                              key, input_param.shape, param.shape))
+                        continue
+                    try:
+                        with torch.no_grad():
+                            param.copy_(input_param)
+                    except Exception as ex:
+                        error_msgs.append('While copying the parameter named "{}", '
+                                          'whose dimensions in the model are {} and '
+                                          'whose dimensions in the checkpoint are {}, '
+                                          'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
+                                                                               ex.args))
+            elif strict:
+                missing_keys.append(key)
+
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state:
+            if extra_state_key in state_dict:
+                self.set_extra_state(state_dict[extra_state_key])
+            elif strict:
+                missing_keys.append(extra_state_key)
+        elif strict and (extra_state_key in state_dict):
+            unexpected_keys.append(extra_state_key)
+
+        if strict:
+            for key in state_dict.keys():
+                if key.startswith(prefix) and key != extra_state_key:
+                    input_name = key[len(prefix):]
+                    input_name = input_name.split('.', 1)[0]    # get the name of param/buffer/child
+                    if input_name not in self._modules and input_name not in local_state:
+                        unexpected_keys.append(key)
 
     def __getitem__(self, idx: int):
         assert isinstance(self.module, nn.ModuleList)
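A rough usage sketch of the save/load round trip this patch enables is given below. It is illustrative only, not a verified example: distributed initialization (e.g. via colossalai.launch) is assumed to have been done already, the ZeroInitContext and ShardedModelV2 constructor arguments shown are assumed to match this revision of the codebase, and MyModel is a placeholder module.

import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2


class MyModel(nn.Module):
    # placeholder model, used only for this sketch

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc(x)


shard_strategy = TensorShardStrategy()

# While the context is active, nn.Module.state_dict and nn.Module._load_from_state_dict
# are temporarily patched, so checkpoint handling sees gathered (unsharded) tensors.
with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = MyModel()

sharded_model = ShardedModelV2(model, shard_strategy)

# state_dict() gathers the shards, deep-copies the full tensors, then re-shards,
# so every rank receives a complete copy of the parameters.
full_state = sharded_model.state_dict()

# load_state_dict() copies the full tensors back into the colo_attr payloads and
# forces a re-shard, so it accepts the unsharded checkpoint produced above.
sharded_model.load_state_dict(full_state)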