mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] add Parameter inheritance for ColoParameter (#1041)
* add Parameter inheritance for ColoParameter
* remove tricks
* polish
parent 4d8a574cd3
commit 7c530b9de2
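In short: once ColoParameter also subclasses torch.nn.Parameter, the stock torch.nn.Module machinery recognizes it, which is what lets the patched Module methods (the "tricks") below be deleted. A minimal sketch of the effect, assuming the default replicated spec and a single process with no distributed init:

import torch
from colossalai.tensor import ColoParameter

# Plain attribute assignment now takes the stock
# nn.Module.__setattr__ -> register_parameter path, because
# isinstance(value, torch.nn.Parameter) holds for ColoParameter.
layer = torch.nn.Linear(4, 4)
layer.weight = ColoParameter(torch.randn(4, 4))
assert isinstance(layer.weight, torch.nn.Parameter)
assert 'weight' in dict(layer.named_parameters())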
@@ -3,15 +3,15 @@ from .const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
 from typing import Optional
 
 
-class ColoParameter(ColoTensor):
+class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
 
     """
 
     def __new__(cls,
-                data: torch.Tensor,
+                data: Optional[torch.Tensor] = None,
                 requires_grad: bool = True,
                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
         if data is None:
@@ -19,7 +19,7 @@ class ColoParameter(ColoTensor):
         return torch.Tensor._make_subclass(cls, data, requires_grad)
 
     def __init__(self,
-                 data: torch.Tensor,
+                 data: Optional[torch.Tensor] = None,
                  requires_grad: bool = True,
                  spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
         self._spec = copy(spec)
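A note on the construction pattern above: tensor subclasses allocate in __new__ via torch.Tensor._make_subclass (private but long-standing), which rewraps existing storage under the subclass type; __init__ then only attaches metadata such as _spec. The same pattern in a generic, hypothetical subclass:

import torch

class MyParam(torch.nn.Parameter):

    # Mirrors ColoParameter.__new__: wrap (possibly absent) data with the
    # subclass type instead of copying it; `data is None` falls back to an
    # empty placeholder, matching the Optional[torch.Tensor] default.
    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

p = MyParam(torch.randn(2, 2))
assert type(p) is MyParam and p.requires_grad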
@@ -43,4 +43,30 @@ class ColoParameter(ColoTensor):
 
     def __repr__(self):
         return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        else:
+            with torch._C.DisableTorchFunction():
+                data = self.data.clone()
+            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
+            memo[id(self)] = tensor
+            return tensor
+
+    def __reduce_ex__(self, proto):
+        # Adapted from torch._utils._rebuild_parameter
+        # def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
+        #     colo_param = ColoParameter(data, requires_grad)
+        #     colo_param._backward_hooks = backward_hooks
+        #     return colo_param
+
+        # return (
+        #     _rebuild_colo_parameter,
+        #     (self.data, self.requires_grad, OrderedDict())
+        # )
+
+        # TODO(jzy) we don't support object reflection now.
+        # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
+        raise NotImplementedError
 
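The net effect of the two new dunders, as a usage sketch: deepcopy rebuilds a ColoParameter around a cloned payload and a copied spec, while pickling is rejected outright because the spec carries a runtime process_group that cannot be rebuilt from bytes.

import copy
import pickle
import torch
from colossalai.tensor import ColoParameter

p = ColoParameter(torch.randn(2, 2))
q = copy.deepcopy(p)                    # dispatches to __deepcopy__
assert isinstance(q, ColoParameter)
assert q.data_ptr() != p.data_ptr()     # payload cloned, not aliased

try:
    pickle.dumps(p)                     # dispatches to __reduce_ex__
except NotImplementedError:
    pass                                # expected until object reflection lands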
@@ -24,96 +24,6 @@ def _named_params_with_replica(
         name = mod_prefix + ('.' if mod_prefix else '') + name
         yield name, val
 
 
-# Adapted from torch.nn.module.Module.register_param
-def _register_parameter_with_colotensor(self, name: str, param):
-    if '_parameters' not in self.__dict__:
-        raise AttributeError("cannot assign parameter before Module.__init__() call")
-
-    if not isinstance(name, torch._six.string_classes):
-        raise TypeError("parameter name should be a string. "
-                        "Got {}".format(torch.typename(name)))
-    if '.' in name:
-        raise KeyError("parameter name can't contain \".\"")
-    if name == '':
-        raise KeyError("parameter name can't be empty string \"\"")
-    if hasattr(self, name) and name not in self._parameters:
-        raise KeyError("attribute '{}' already exists".format(name))
-
-    if param is None:
-        self._parameters[name] = None
-    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
-        raise TypeError("cannot assign '{}' object to parameter '{}' "
-                        "(torch.nn.Parameter or ColoParameter or None required)".format(torch.typename(param), name))
-    elif param.grad_fn:
-        raise ValueError("Cannot assign non-leaf Tensor to parameter '{0}'. Model "
-                         "parameters must be created explicitly. To express '{0}' "
-                         "as a function of another Tensor, compute the value in "
-                         "the forward() method.".format(name))
-    else:
-        self._parameters[name] = param
-
-
-# Adapted from torch.nn.module.Module.__setattr__
-def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):
-
-    def remove_from(*dicts_or_sets):
-        for d in dicts_or_sets:
-            if name in d:
-                if isinstance(d, dict):
-                    del d[name]
-                else:
-                    d.discard(name)
-
-    params = self.__dict__.get('_parameters')
-    if isinstance(value, (ColoParameter, torch.nn.Parameter)):
-        if params is None:
-            raise AttributeError("cannot assign parameters before Module.__init__() call")
-        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
-        self.register_parameter(name, value)
-    elif params is not None and name in params:
-        if value is not None:
-            raise TypeError("cannot assign '{}' as parameter '{}' "
-                            "(torch.nn.Parameter or None expected)".format(torch.typename(value), name))
-        self.register_parameter(name, value)
-    else:
-        modules = self.__dict__.get('_modules')
-        if isinstance(value, torch.nn.Module):
-            if modules is None:
-                raise AttributeError("cannot assign module before Module.__init__() call")
-            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
-            modules[name] = value
-        elif modules is not None and name in modules:
-            if value is not None:
-                raise TypeError("cannot assign '{}' as child module '{}' "
-                                "(torch.nn.Module or None expected)".format(torch.typename(value), name))
-            modules[name] = value
-        else:
-            buffers = self.__dict__.get('_buffers')
-            if buffers is not None and name in buffers:
-                if value is not None and not isinstance(value, torch.Tensor):
-                    raise TypeError("cannot assign '{}' as buffer '{}' "
-                                    "(torch.Tensor or None expected)".format(torch.typename(value), name))
-                buffers[name] = value
-            else:
-                object.__setattr__(self, name, value)
-
-
-def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
-    module_path, _, param_name = target.rpartition(".")
-
-    mod: torch.nn.Module = self.get_submodule(module_path)
-
-    if not hasattr(mod, param_name):
-        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
-
-    param = getattr(mod, param_name)
-    return param
-
-
 def ColoModulize(module):
     """
     Replacing the parameters() and named_parameters() with our customized ones
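For contrast with the deleted helpers: everything they reimplemented is now covered by the unmodified Module methods, since the isinstance checks against torch.nn.Parameter succeed for ColoParameter. A small sketch with an illustrative module:

import torch
from colossalai.tensor import ColoParameter

class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Registered by the stock Module.__setattr__; the removed
        # _register_parameter_with_colotensor is no longer involved.
        self.w = ColoParameter(torch.randn(3))

net = Net()
assert net.get_parameter('w') is net.w    # stock Module.get_parameter
assert 'w' in net.state_dict()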
@@ -134,10 +44,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device
 
-        torch.nn.Module.__setattr__ = _setattr_with_colotensor
-        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
-        torch.nn.Module.get_parameter = _get_parameter_with_colotensor
-
         self._register_colo_modules()
 
     def _register_colo_modules(self):
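For orientation, ColoInitContext is used as a context manager whose post-init hook converts parameters created inside it; a usage sketch, with the import path assumed from this file's location:

import torch
from colossalai.utils.model.colo_init_context import ColoInitContext

# Parameters instantiated under the context come out as ColoParameter,
# now without any monkey-patching of torch.nn.Module.
with ColoInitContext(device=torch.device('cpu')):
    model = torch.nn.Linear(8, 8)

assert all(type(p).__name__ == 'ColoParameter' for p in model.parameters())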
@@ -353,5 +353,5 @@ def _test_pretrain_load(world_size):
 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    test_model(4)
-    # _test_pretrain_load(4)
+    # test_model(4)
+    _test_pretrain_load(4)
@@ -0,0 +1,26 @@
+from colossalai.tensor import ColoParameter, ColoTensor
+import torch
+from numpy import allclose
+from _utils import tensor_equal
+
+def test_multiinheritance():
+    colo_param = ColoParameter()
+    assert isinstance(colo_param, ColoTensor)
+    assert isinstance(colo_param, torch.nn.Parameter)
+
+    # __deepcopy__ overload
+    import copy
+    colo_param2 = copy.deepcopy(colo_param)
+    assert isinstance(colo_param2, ColoParameter)
+    assert tensor_equal(colo_param.data, colo_param2.data)
+    assert colo_param.requires_grad == colo_param2.requires_grad
+
+    # __repr__ overload
+    assert 'ColoParameter' in str(colo_param)
+
+    # __torch_function__
+    clone_param = torch.clone(colo_param)
+    assert isinstance(clone_param, ColoTensor)
+
+if __name__ == '__main__':
+    test_multiinheritance()
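The last assertion leans on __torch_function__ dispatch: torch.Tensor subclasses get op results re-wrapped in the subclass by default, and ColoTensor's override preserves this. A generic illustration with a hypothetical subclass:

import torch

# Bare subclass: the default __torch_function__ already re-wraps op
# results in the subclass, which is why torch.clone(colo_param) comes
# back as a ColoTensor in the test above.
class TaggedTensor(torch.Tensor):
    pass

t = torch.randn(2).as_subclass(TaggedTensor)
assert isinstance(torch.clone(t), TaggedTensor)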
@@ -46,3 +46,4 @@ def test_operand():
     t_ref_res = t_ref + t_ref
     t_res = t + t
     assert torch.allclose(t_ref_res, t_res)
+