[Tensor] add Parameter inheritance for ColoParameter (#1041)

* add Parameter inheritance for ColoParameter

* remove tricks

* remove tricks

* polish

* polish
Ziyue Jiang 2022-05-30 17:23:44 +08:00 committed by GitHub
parent 4d8a574cd3
commit 7c530b9de2
5 changed files with 59 additions and 100 deletions


@@ -3,15 +3,15 @@ from .const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
+from typing import Optional


-class ColoParameter(ColoTensor):
+class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
     """

     def __new__(cls,
-                data: torch.Tensor,
+                data: Optional[torch.Tensor] = None,
                 requires_grad: bool = True,
                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
         if data is None:
@@ -19,7 +19,7 @@ class ColoParameter(ColoTensor):
         return torch.Tensor._make_subclass(cls, data, requires_grad)

     def __init__(self,
-                 data: torch.Tensor,
+                 data: Optional[torch.Tensor] = None,
                  requires_grad: bool = True,
                  spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
         self._spec = copy(spec)
@@ -43,4 +43,30 @@ class ColoParameter(ColoTensor):

     def __repr__(self):
         return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        else:
+            with torch._C.DisableTorchFunction():
+                data = self.data.clone()
+            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
+            memo[id(self)] = tensor
+            return tensor
+
+    def __reduce_ex__(self, proto):
+        # Adapted from torch._utils._rebuild_parameter
+        # def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
+        #     colo_param = ColoParameter(data, requires_grad)
+        #     colo_param._backward_hooks = backward_hooks
+        #     return colo_param
+        # return (
+        #     _rebuild_colo_parameter,
+        #     (self.data, self.requires_grad, OrderedDict())
+        # )
+        # TODO(jzy) we don't support object reflection now.
+        # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
+        raise NotImplementedError
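
With torch.nn.Parameter added as a second base class, a ColoParameter now passes the same isinstance checks as a plain parameter, copy.deepcopy routes through the new __deepcopy__ overload, and pickling is rejected explicitly by __reduce_ex__. The following is a minimal usage sketch inferred from the hunk above, not a verified snippet from the repository:

    import copy
    import pickle

    import torch
    from colossalai.tensor import ColoParameter

    # Construct a ColoParameter from a dense tensor; the default replicated
    # TensorSpec is used when no spec is passed.
    param = ColoParameter(torch.randn(4, 4))
    assert isinstance(param, torch.nn.Parameter)

    # __deepcopy__ clones the payload and re-wraps it with a copy of the spec.
    param_copy = copy.deepcopy(param)
    assert isinstance(param_copy, ColoParameter)
    assert param_copy.requires_grad == param.requires_grad

    # __reduce_ex__ is deliberately unimplemented, so pickling fails loudly
    # instead of silently dropping the distributed spec.
    try:
        pickle.dumps(param)
    except NotImplementedError:
        pass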


@@ -24,96 +24,6 @@ def _named_params_with_replica(
         name = mod_prefix + ('.' if mod_prefix else '') + name
         yield name, val

-# Adapted from torch.nn.module.Module.register_param
-def _register_parameter_with_colotensor(self, name: str, param):
-    if '_parameters' not in self.__dict__:
-        raise AttributeError("cannot assign parameter before Module.__init__() call")
-
-    if not isinstance(name, torch._six.string_classes):
-        raise TypeError("parameter name should be a string. "
-                        "Got {}".format(torch.typename(name)))
-    if '.' in name:
-        raise KeyError("parameter name can't contain \".\"")
-    if name == '':
-        raise KeyError("parameter name can't be empty string \"\"")
-    if hasattr(self, name) and name not in self._parameters:
-        raise KeyError("attribute '{}' already exists".format(name))
-
-    if param is None:
-        self._parameters[name] = None
-    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
-        raise TypeError("cannot assign '{}' object to parameter '{}' "
-                        "(torch.nn.Parameter or ColoParameter or None required)".format(torch.typename(param), name))
-    elif param.grad_fn:
-        raise ValueError("Cannot assign non-leaf Tensor to parameter '{0}'. Model "
-                         "parameters must be created explicitly. To express '{0}' "
-                         "as a function of another Tensor, compute the value in "
-                         "the forward() method.".format(name))
-    else:
-        self._parameters[name] = param
-
-# Adapted from torch.nn.module.Module.__setattr__
-def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):
-    def remove_from(*dicts_or_sets):
-        for d in dicts_or_sets:
-            if name in d:
-                if isinstance(d, dict):
-                    del d[name]
-                else:
-                    d.discard(name)
-
-    params = self.__dict__.get('_parameters')
-    if isinstance(value, (ColoParameter, torch.nn.Parameter)):
-        if params is None:
-            raise AttributeError("cannot assign parameters before Module.__init__() call")
-        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
-        self.register_parameter(name, value)
-    elif params is not None and name in params:
-        if value is not None:
-            raise TypeError("cannot assign '{}' as parameter '{}' "
-                            "(torch.nn.Parameter or None expected)".format(torch.typename(value), name))
-        self.register_parameter(name, value)
-    else:
-        modules = self.__dict__.get('_modules')
-        if isinstance(value, torch.nn.Module):
-            if modules is None:
-                raise AttributeError("cannot assign module before Module.__init__() call")
-            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
-            modules[name] = value
-        elif modules is not None and name in modules:
-            if value is not None:
-                raise TypeError("cannot assign '{}' as child module '{}' "
-                                "(torch.nn.Module or None expected)".format(torch.typename(value), name))
-            modules[name] = value
-        else:
-            buffers = self.__dict__.get('_buffers')
-            if buffers is not None and name in buffers:
-                if value is not None and not isinstance(value, torch.Tensor):
-                    raise TypeError("cannot assign '{}' as buffer '{}' "
-                                    "(torch.Tensor or None expected)".format(torch.typename(value), name))
-                buffers[name] = value
-            else:
-                object.__setattr__(self, name, value)
-
-def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
-    module_path, _, param_name = target.rpartition(".")
-    mod: torch.nn.Module = self.get_submodule(module_path)
-
-    if not hasattr(mod, param_name):
-        raise AttributeError(mod._get_name() + " has no attribute `"
-                             + param_name + "`")
-
-    param = getattr(mod, param_name)
-    return param
-
 def ColoModulize(module):
     """
     Replacing the parameters() and named_parameters() with our customized ones
@@ -134,10 +44,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device

-        torch.nn.Module.__setattr__ = _setattr_with_colotensor
-        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
-        torch.nn.Module.get_parameter = _get_parameter_with_colotensor
-
         self._register_colo_modules()

     def _register_colo_modules(self):
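
The helper functions deleted above existed only because the stock torch.nn.Module.__setattr__ and register_parameter refuse values that are not torch.nn.Parameter instances. With ColoParameter now inheriting from torch.nn.Parameter, the unpatched module machinery already files it under _parameters, which is the assumption that lets this commit drop the monkey patches. A small sanity sketch of that assumption (ToyModule is illustrative only, not part of the repository):

    import torch
    from colossalai.tensor import ColoParameter

    class ToyModule(torch.nn.Module):

        def __init__(self):
            super().__init__()
            # No patched __setattr__ or register_parameter is involved: the
            # plain isinstance(value, torch.nn.Parameter) branch matches.
            self.weight = ColoParameter(torch.zeros(2, 2))

    toy = ToyModule()
    assert 'weight' in dict(toy.named_parameters())
    assert all(isinstance(p, ColoParameter) for p in toy.parameters())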


@@ -353,5 +353,5 @@ def _test_pretrain_load(world_size):
 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    test_model(4)
-    # _test_pretrain_load(4)
+    # test_model(4)
+    _test_pretrain_load(4)


@@ -0,0 +1,26 @@
+from colossalai.tensor import ColoParameter, ColoTensor
+import torch
+from numpy import allclose
+from _utils import tensor_equal
+
+
+def test_multiinheritance():
+    colo_param = ColoParameter()
+    assert isinstance(colo_param, ColoTensor)
+    assert isinstance(colo_param, torch.nn.Parameter)
+
+    # __deepcopy__ overload
+    import copy
+    colo_param2 = copy.deepcopy(colo_param)
+    assert isinstance(colo_param2, ColoParameter)
+    assert tensor_equal(colo_param.data, colo_param2.data)
+    assert colo_param.requires_grad == colo_param2.requires_grad
+
+    # __repr__ overload
+    assert 'ColoParameter' in str(colo_param)
+
+    # __torch_function__
+    clone_param = torch.clone(colo_param)
+    assert isinstance(clone_param, ColoTensor)
+
+if __name__ == '__main__':
+    test_multiinheritance()
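
Note that the test constructs ColoParameter() with no arguments; this only works because the colo_parameter.py hunk made data an Optional with a None default, mirroring how torch.nn.Parameter() falls back to an empty tensor. A brief, unverified illustration of that parallel:

    import torch
    from colossalai.tensor import ColoParameter

    # Both default to an empty, gradient-enabled parameter when no data is given.
    empty_torch = torch.nn.Parameter()
    empty_colo = ColoParameter()
    assert empty_torch.requires_grad and empty_colo.requires_grad
    assert isinstance(empty_colo, torch.nn.Parameter)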


@@ -46,3 +46,4 @@ def test_operand():
     t_ref_res = t_ref + t_ref
     t_res = t + t
     assert torch.allclose(t_ref_res, t_res)