mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
530 lines
24 KiB
530 lines
24 KiB
import torch
|
|
import itertools
|
|
import torch.distributed as dist
|
|
from functools import partial
|
|
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
|
from colossalai.gemini.chunk import TensorState, Chunk
|
|
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
|
from typing import Dict, Iterable, List, Optional, Set
|
|
from colossalai.logging import get_dist_logger
|
|
from collections import OrderedDict
|
|
from colossalai.tensor.colo_parameter import ColoParameter
|
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
|
from .reducer import Reducer
|
|
|
|
# `_IncompatibleKeys` is imported on its own so that it stays available even when
# `_EXTRA_STATE_KEY_SUFFIX` cannot be imported from the installed torch; a combined
# import would leave `_IncompatibleKeys` undefined after the fallback below.
from torch.nn.modules.module import _IncompatibleKeys

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
    # Fallback value used as the key suffix for a module's extra state.
    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
|
|
|
|
|
def free_storage(data: torch.Tensor) -> None:
    """Release the memory backing ``data`` by shrinking its storage to zero elements.

    Nothing happens when the storage is already empty.

    Args:
        data (torch.Tensor): Tensor whose storage is resized in place. It must start
            at offset 0 of its storage.

    Raises:
        AssertionError: If the storage is non-empty and ``data`` has a non-zero
            storage offset.
    """
    storage = data.storage()
    if storage.size() == 0:
        return
    # The storage itself is resized, so the tensor is required to begin at the
    # very start of it (i.e. it must not be an offset view into a shared storage).
    assert data.storage_offset() == 0
    storage.resize_(0)
|
|
|
|
|
|
def _cast_float(args, dtype: torch.dtype):
    """Recursively cast every floating-point tensor found in ``args`` to ``dtype``.

    Lists, tuples (including namedtuples) and dicts are traversed; any other value,
    including non-floating-point tensors, is returned unchanged.

    Args:
        args: A tensor, a (possibly nested) list/tuple/dict of values, or any other object.
        dtype (torch.dtype): Target dtype for floating-point tensors.

    Returns:
        A structure mirroring ``args``. Lists and tuples keep their concrete type;
        dicts are rebuilt as plain ``dict`` objects.
    """
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        return args.to(dtype)
    if isinstance(args, (list, tuple)):
        items = [_cast_float(t, dtype) for t in args]
        if isinstance(args, tuple) and hasattr(args, '_fields'):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable, so the items have to be unpacked.
            return type(args)(*items)
        return type(args)(items)
    if isinstance(args, dict):
        return {k: _cast_float(v, dtype) for k, v in args.items()}
    return args
|
|
|
|
|
|
class ColoDDP(torch.nn.Module):
    """Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now.

    Example:
        >>> from colossalai.core import global_context as gpc
        >>> from colossalai.context import ParallelMode
        >>> model = torch.nn.Linear(20, 1)
        >>> pg = ProcessGroup(tp_degree = world_size//2)
        >>> model = ColoDDP(model, pg)
        >>> logits = model(x)
        >>> loss = criterion(logits, labels)
        >>> model.backward(loss)

    Args:
        module (torch.nn.Module): Module to apply DDP.
        process_group (ColoProcessGroup): The process group which DDP uses. It is required:
            the constructor asserts that it is truthy.
        bucket_cap_mb (int, optional): Forwarded to ``Reducer``; presumably the bucket
            capacity in MB (confirm against ``Reducer``). Defaults to 25.
        rebuild_bucket (bool, optional): If True, ``self.reducer.free()`` is called at the
            end of every :meth:`backward`. Defaults to True.
    """

    def __init__(self,
                 module: torch.nn.Module,
                 process_group: 'ColoProcessGroup',
                 bucket_cap_mb: int = 25,
                 rebuild_bucket: bool = True) -> None:
        assert not isinstance(module, ColoDDP)
        super().__init__()
        self.module = module
        # Dedicated CUDA stream on which gradient reduction is issued.
        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        assert process_group

        self.process_group = process_group
        self.dp_world_size = self.process_group.dp_world_size()

        self.reducer = Reducer(bucket_cap_mb)
        self.rebuild_bucket = rebuild_bucket
        for p in module.parameters():
            if getattr(p, '_ddp_to_ignore', False):
                continue
            if p.requires_grad:
                # The hook receives the parameter itself in addition to its gradient.
                p.register_hook(partial(self.grad_handle, p))

    def parameters(self, recurse: bool = True):
        """Return the parameters of the wrapped module."""
        return self.module.parameters(recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        """Return the named parameters of the wrapped module (no extra ``module.`` prefix)."""
        return self.module.named_parameters(prefix, recurse)

    def named_buffers(self, prefix: str = '', recurse: bool = True):
        """Return the named buffers of the wrapped module."""
        return self.module.named_buffers(prefix, recurse)

    def named_children(self):
        """Return the named children of the wrapped module."""
        return self.module.named_children()

    def named_modules(self,
                      memo: Optional[Set[torch.nn.Module]] = None,
                      prefix: str = '',
                      remove_duplicate: bool = True):
        """Return the named sub-modules of the wrapped module."""
        return self.module.named_modules(memo, prefix, remove_duplicate)

    def forward(self, *args, **kwargs):
        """Clear ``.grad`` on the wrapped module's parameters, then run its forward pass."""
        self.module.zero_grad(set_to_none=True)
        return self.module(*args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Run backward on ``loss``, flush pending reductions and expose the reduced grads.

        After the call, ``p.grad`` of every non-ignored parameter that received a
        non-CPU gradient points to ``p._saved_grad``.

        Args:
            loss (torch.Tensor): The tensor to call ``backward()`` on.
        """
        loss.backward()
        with torch.cuda.stream(self.comm_stream):
            self.reducer.flush()
        # Make the current stream wait for the work queued on the communication stream.
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        if self.rebuild_bucket:
            self.reducer.free()
        for p in self.module.parameters():
            if getattr(p, '_ddp_to_ignore', False):
                continue
            if p.grad is None:
                # No gradient was produced for this parameter in this pass
                # (e.g. it does not require grad or was unused); nothing to swap in.
                continue
            if p.grad.device.type != "cpu":
                p.grad = p._saved_grad

    def grad_handle(self, p, grad):
        """Gradient hook registered on every non-ignored parameter that requires grad.

        For non-CPU gradients the (possibly reduced) gradient is stored through
        :meth:`_save_grad` and a storage-less placeholder is returned to autograd.
        For CPU gradients the gradient is all-reduced in place and returned.

        Args:
            p: The parameter the hook was registered on.
            grad (torch.Tensor): The gradient computed by autograd for ``p``.

        Returns:
            torch.Tensor: The tensor autograd should use as ``p``'s gradient.
        """
        if grad.device.type != "cpu":
            empty_grad = torch.empty_like(grad)
            free_storage(empty_grad)
            if self.dp_world_size > 1:
                # Pre-divide so that the following all-reduce yields an average.
                grad = grad / self.dp_world_size
                self.comm_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.comm_stream):
                    self.reducer.all_reduce_async(grad,
                                                  group=self.process_group.dp_process_group(),
                                                  callback_fn=partial(self._save_grad, p))
                # `grad` is consumed on comm_stream; tell the allocator about it.
                grad.record_stream(self.comm_stream)
            else:
                ColoDDP._save_grad(p, grad)
            return empty_grad

        else:
            # TODO(jiaruifang) fixme
            # NOTE(review): unlike the branch above, the gradient is not divided by
            # dp_world_size before the all-reduce here -- confirm whether that is intended.
            self.process_group.set_cpu_groups()
            dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
            return grad

    @staticmethod
    def _save_grad(p, grad):
        """Accumulate ``grad`` into ``p._saved_grad``, creating it on first use.

        ``p._saved_grad`` may be missing (first call) or ``None`` (after
        ``zero_grad(set_to_none=True)``); in both cases ``grad`` itself is stored.

        Args:
            p: The parameter that owns the saved gradient.
            grad (torch.Tensor): The gradient to store or add in place.
        """
        if getattr(p, '_saved_grad', None) is not None:
            p._saved_grad.add_(grad)
        else:
            p._saved_grad = grad

    def zero_grad(self, set_to_none: bool = False) -> None:
        """Clear ``.grad`` of the wrapped module and reset every ``_saved_grad``.

        ``.grad`` is always set to ``None``; ``set_to_none`` only controls how
        ``_saved_grad`` is reset.

        Args:
            set_to_none (bool, optional): If True, ``_saved_grad`` is set to ``None``;
                otherwise it is detached and zeroed in place. Defaults to False.
        """
        self.module.zero_grad(set_to_none=True)
        for p in self.module.parameters():
            if getattr(p, '_saved_grad', None) is not None:
                if set_to_none:
                    p._saved_grad = None
                else:
                    if p._saved_grad.grad_fn is not None:
                        p._saved_grad.detach_()
                    else:
                        p._saved_grad.requires_grad_(False)
                    p._saved_grad.zero_()

    @staticmethod
    def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
        """Sets parameters to be ignored by DDP.
        This method must be called before initializing ColoDDP.

        Example:
            >>> params_to_ignore = []
            >>> for p in module.parameters():
            >>>     if should_ignore(p):
            >>>         params_to_ignore.append(p)
            >>> ColoDDP.set_params_to_ignore(params_to_ignore)
            >>> module = ColoDDP(module, pg)

        Args:
            params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
        """
        for p in params_to_ignore:
            p._ddp_to_ignore = True

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the state dict of the wrapped module."""
        return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        """Load ``state_dict`` into the wrapped module and return its result."""
        return self.module.load_state_dict(state_dict, strict)
|
|
|
|
|
|
class ZeroDDP(ColoDDP):
    """ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now.
    We can configure chunk and gemini via ChunkManager and GeminiManager respectively.
    For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.

    Example:
        >>> model = torch.nn.Linear(20, 1)
        >>> placement_policy = 'cuda'
        >>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
        >>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
        >>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
        >>> model = ZeroDDP(model, gemini_manager)
        >>> logits = model(x)
        >>> loss = criterion(logits, labels)
        >>> model.backward(loss)

    Args:
        module (torch.nn.Module): Module to apply ZeRO-DP.
        gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
            For more details, see the API reference of ``GeminiManager``.
        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
    """

    def __init__(self,
                 module: torch.nn.Module,
                 gemini_manager: GeminiManager,
                 force_outputs_fp32: bool = False) -> None:
        super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
        self.gemini_manager = gemini_manager
        self.chunk_manager = gemini_manager.chunk_manager
        self.force_outputs_fp32 = force_outputs_fp32
        self.param_op_hook = ZeROHookV2(gemini_manager)
        # fp32 copies of every non-ignored parameter, in `module.parameters()` order.
        self.fp32_params: List[ColoParameter] = []
        self.overflow_counter = 0
        # Device each parameter's chunk is moved to after its gradient is reduced.
        self.grads_device: Dict[torch.Tensor, torch.device] = {}
        self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
        self.chunk_manager.create_group('fp32_param')
        # TODO: get param order and filter unused params
        for p in module.parameters():
            if getattr(p, '_ddp_to_ignore', False):
                # Ignored params are cast to half but are not registered with the chunk manager.
                p.data = p.half()
                continue
            fp32_p = p.float().detach()
            p.data = p.half()
            self.chunk_manager.append_tensor(p, 'fp16_param')
            self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
            self.fp32_params.append(fp32_p)
            self.grads_device[p] = self.gemini_manager.default_device
        self._cast_buffers()
        self._logger = get_dist_logger()

    def forward(self, *args, **kwargs):
        """Cast floating-point inputs to half, run the wrapped module under the param-op hook.

        Returns:
            The module outputs, cast to fp32 when ``force_outputs_fp32`` is set.
        """
        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
        self.module.zero_grad(set_to_none=True)
        self.gemini_manager.pre_iter()
        with ParamOpHookManager.use_hooks(self.param_op_hook):
            outputs = self.module(*args, **kwargs)
        self.chunk_manager.exec_lazy_release()
        if self.force_outputs_fp32:
            return _cast_float(outputs, torch.float)
        return outputs

    def _setup_grads_ptr(self):
        """Point ``p.grad`` at ``p.data`` for every non-ignored parameter that has local data.

        ``p.grad`` is set to ``None`` when the parameter's chunk is empty or the
        parameter does not require grad.
        """
        for p in self.module.parameters():
            if getattr(p, '_ddp_to_ignore', False):
                continue
            if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
                p.grad = None
            else:
                # grad_handle copies the gradient into the parameter's chunk slice,
                # so the gradient is exposed through the parameter's own data.
                p.grad = p.data

    def _post_backward(self):
        """Release chunks, set up ``p.grad``, log gemini statistics and finish the iteration."""
        self.chunk_manager.exec_lazy_release()
        self._setup_grads_ptr()
        self._logger.debug(
            f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
        )
        self.gemini_manager.post_iter()

    def backward(self, loss: torch.Tensor):
        """Run ``loss.backward()`` with the param-op hook switched to backward mode."""
        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
            loss.backward()
        self._post_backward()

    def backward_by_grad(self, tensor, grad):
        """Run ``torch.autograd.backward(tensor, grad)`` with the param-op hook in backward mode."""
        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
            torch.autograd.backward(tensor, grad)
        self._post_backward()

    def grad_handle(self, p, grad):
        """Gradient hook: copy ``grad`` into ``p``'s chunk slice and reduce the chunk.

        Args:
            p: The parameter the hook was registered on.
            grad (torch.Tensor): The gradient computed by autograd for ``p``.

        Returns:
            torch.Tensor: A storage-less placeholder handed back to autograd.
        """
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        with torch._C.DisableTorchFunction():
            self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
            if self.dp_world_size > 1:
                grad = grad / self.dp_world_size
            self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
            chunk = self.chunk_manager.get_chunk(p)
            reduced = self.chunk_manager.reduce_chunk(chunk)
            self.chunk_manager.release_chunk(chunk)
            if reduced and not chunk.is_empty:
                self.overflow_counter += chunk.has_inf_or_nan
                self.chunk_manager.move_chunk(chunk, self.grads_device[p])
        return empty_grad

    def zero_grad(self, set_to_none: bool = False) -> None:
        """Set ``.grad`` of all wrapped parameters to ``None`` (``set_to_none`` is not consulted)."""
        self.module.zero_grad(set_to_none=True)

    def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
        """Record ``device`` as the gradient device of every tensor in ``chunk``."""
        for tensor in chunk.get_tensors():
            self.grads_device[tensor] = device

    def _iter_named_params_with_fp32(self):
        """Yield ``(name, p, fp32_p)`` for every named parameter.

        ``fp32_params`` only holds entries for parameters that are not marked
        ``_ddp_to_ignore``, so the two sequences cannot be zipped directly.
        For ignored parameters ``fp32_p`` is ``None``. Iteration stops as soon as
        ``fp32_params`` is exhausted for a non-ignored parameter.
        """
        fp32_iter = iter(self.fp32_params)
        for name, p in self.named_parameters():
            if getattr(p, '_ddp_to_ignore', False):
                yield name, p, None
                continue
            fp32_p = next(fp32_iter, None)
            if fp32_p is None:
                return
            yield name, p, fp32_p

    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
        r"""Returns a dictionary containing a whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are corresponding parameter and buffer names.
        Parameters and buffers set to ``None`` are not included.

        Args:
            destination (dict, optional): dict to fill; a new ``OrderedDict`` is created when ``None``.
            prefix (str): prefix prepended to every key.
            keep_vars (bool): if False, saved tensors are detached.
            only_rank_0 (bool): if True, only the process whose ``dp_local_rank()`` is 0 records
                the parameter values; other processes store empty tensors for chunk-managed
                parameters. Defaults to True.

        Returns:
            dict:
                a dictionary containing a whole state of the module

        Example:

            >>> module.state_dict().keys()
            ['bias', 'weight']

        """
        is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
        record_flag = (not only_rank_0) or is_rank_0

        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
        self._save_to_state_dict(destination, prefix, keep_vars, record_flag)

        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination

    def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True):
        r"""Saves module state to `destination` dictionary, containing a state
        of the module, but not its descendants. This is called on every
        submodule in :meth:`~torch.nn.Module.state_dict`.

        In rare cases, subclasses can achieve class-specific behavior by
        overriding this method with custom logic.

        Args:
            destination (dict): a dict where state will be stored
            prefix (str): the prefix for parameters and buffers used in this
                module
            keep_vars (bool): if False, saved tensors are detached
            record_flag (bool): if False, chunk-managed parameters are stored as
                empty tensors instead of CPU copies
        """
        # save parameters
        param_to_save_data = dict()
        chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
        for chunk in chunk_list:
            # record the original device of the chunk
            org_chunk_dev_typ = chunk.device_type
            self.chunk_manager.access_chunk(chunk)

            for tensor in chunk.get_tensors():
                rec_p = torch.empty([0])
                if record_flag:
                    rec_p = tensor.cpu()    # move the whole tensor to CPU mem
                assert tensor not in param_to_save_data
                param_to_save_data[tensor] = rec_p
            # release the actual memory of the chunk
            self.chunk_manager.release_chunk(chunk)
            if not chunk.is_empty and org_chunk_dev_typ == 'cpu':
                self.chunk_manager.move_chunk(chunk, torch.device('cpu'))

        for name, p, fp32_p in self._iter_named_params_with_fp32():
            if p is None:
                continue
            if fp32_p is None:
                # Ignored parameters are not chunk-managed: save them directly,
                # the same way buffers are saved below.
                destination[prefix + name] = p if keep_vars else p.detach()
                continue
            assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
            rec_p = param_to_save_data[fp32_p]
            destination[prefix + name] = rec_p if keep_vars else rec_p.detach()

        # save all buffers
        for name, buf in self.named_buffers():
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf if keep_vars else buf.detach()
        # save extra states
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "get_extra_state",
                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
            destination[extra_state_key] = self.get_extra_state()

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        r"""Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys

        Raises:
            RuntimeError: if any error message was collected while loading
                (including key mismatches when ``strict`` is True).

        Note:
            If a parameter or buffer is registered as ``None`` and its corresponding key
            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
            ``RuntimeError``.
        """
        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata    # type: ignore[attr-defined]

        prefix = ''
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        # `True` is passed so that missing/unexpected keys are always collected;
        # whether they become errors is decided by `strict` below.
        self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
                        '"{}"'.format(k) for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        def load(name, dest_tensor, copy_func):
            # Validate the checkpoint entry for `name` against `dest_tensor` and,
            # when the shapes agree, hand it to `copy_func`.
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]
                if input_param.shape != dest_tensor.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                      'the shape in current model is {}.'.format(key, input_param.shape,
                                                                                 dest_tensor.shape))
                    return
                try:
                    with torch.no_grad():
                        copy_func(input_param)
                except Exception as ex:
                    # Collected instead of raised so that all problems are reported together.
                    error_msgs.append('While copying the parameter named "{}", '
                                      'whose dimensions in the model are {} and '
                                      'whose dimensions in the checkpoint are {}, '
                                      'an exception occurred : {}.'.format(key, dest_tensor.size(), input_param.size(),
                                                                           ex.args))
            elif strict:
                missing_keys.append(key)

        def load_fp32_p(fp32_p, data):
            # Only copy when this fp32 tensor currently has storage; presumably it is
            # empty on processes that do not hold the chunk -- TODO confirm.
            if fp32_p.storage().size() > 0:
                self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, data)

        for name, p, fp32_p in self._iter_named_params_with_fp32():
            if p is None:
                continue
            if fp32_p is None:
                # Ignored parameters are not chunk-managed: copy into them directly.
                load(name, p, p.copy_)
            else:
                load(name, fp32_p, partial(load_fp32_p, fp32_p))
        # Refresh the fp16 parameter chunks from the freshly loaded fp32 ones.
        self.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')

        for name, buf in persistent_buffers.items():
            if buf is not None:
                load(name, buf, buf.copy_)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state",
                   torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix):]
                    if input_name not in local_state:
                        unexpected_keys.append(key)

    def _cast_buffers(self):
        """Move every buffer of the wrapped module to CUDA; cast floating-point ones to half."""
        for buffer in self.module.buffers():
            buffer.data = buffer.cuda()
            if torch.is_floating_point(buffer):
                buffer.data = buffer.half()
|