ColossalAI/colossalai/tensor/colo_parameter.py

import torch

from copy import copy
from typing import Optional

from colossalai.tensor import TensorSpec, distspec
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor.param_op_hook import ParamOpHookManager


def filter_args(func, *args):
    return [arg for arg in args if func(arg)]


def replace_args(args, kwargs, new_args):
    args = new_args[:len(args)]
    for k, v in zip(kwargs.keys(), new_args[len(args):]):
        kwargs[k] = v
    return tuple(args), kwargs
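
# Illustrative sketch of how the two helpers above cooperate inside
# ``ColoParameter.__torch_function__`` below: ``pre_op`` returns a (possibly
# rewritten) flat operand list, and ``replace_args`` splices it back into the
# original positional/keyword layout (values are hypothetical, not executed):
#
#   >>> filter_args(lambda a: isinstance(a, int), 1, 'x', 2)
#   [1, 2]
#   >>> replace_args((1, 2), {'alpha': 3}, ['a', 'b', 'c'])
#   (('a', 'b'), {'alpha': 'c'})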


class ColoParameter(ColoTensor, torch.nn.Parameter):
    r"""A kind of ColoTensor to be considered as a module parameter."""

    def __new__(cls,
                data: Optional[torch.Tensor] = None,
                requires_grad: bool = True,
                spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self,
                 data: Optional[torch.Tensor] = None,
                 requires_grad: bool = True,
                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
        self._spec = copy(spec)
        self._type = TensorType.MODEL
        self._graph_node = None
        # a list of modules that share this ColoParameter with other modules
        self._shared_param_modules = []
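
    # Construction sketch: a minimal example of building a ColoParameter
    # directly; ``distspec.replicate()`` is the default placement, so this is
    # an ordinary replicated parameter. In a distributed run the process group
    # would typically be initialized first (e.g. via ``colossalai.launch``);
    # the import path below is an assumption about the package's public API.
    #
    #   >>> import torch
    #   >>> from colossalai.tensor import ColoParameter
    #   >>> p = ColoParameter(torch.randn(4, 4), requires_grad=True)
    #   >>> isinstance(p, torch.nn.Parameter)
    #   True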

    @property
    def shared_param_modules(self):
        return self._shared_param_modules

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor,
                          requires_grad: bool = True,
                          spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
        tensor = tensor.as_subclass(ColoParameter)
        tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
        return tensor
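
    # Usage sketch for ``from_torch_tensor``: it rewraps an existing tensor via
    # ``as_subclass`` instead of allocating new storage, so the resulting
    # ColoParameter shares memory with the source tensor (hypothetical values):
    #
    #   >>> w = torch.randn(8, 8)
    #   >>> p = ColoParameter.from_torch_tensor(w, requires_grad=True)
    #   >>> p.data_ptr() == w.data_ptr()
    #   True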

    def __repr__(self):
        return f'ColoParameter: {torch.Tensor.__repr__(self)}'

    @classmethod
    def __torch_function__(cls, func, types, args=..., kwargs=None):
        if ParamOpHookManager.has_hook():
            # ops whose name starts with a double underscore (e.g. ``__add__``) bypass the hooks
            if not func.__name__.startswith('__'):
                if kwargs is None:
                    kwargs = {}
                params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                if len(params) > 0:
                    # run the hooks with __torch_function__ dispatch disabled to avoid re-entering this method
                    with torch._C.DisableTorchFunction():
                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
                    args, kwargs = replace_args(args, kwargs, new_args)
                    ret = super().__torch_function__(func, types, args, kwargs)
                    with torch._C.DisableTorchFunction():
                        ret = ParamOpHookManager.post_op(params, ret)
                    return ret
        return super().__torch_function__(func, types, args, kwargs)
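
    # Dispatch sketch: when a hook is active (for example one installed through
    # ``ParamOpHookManager.use_hooks`` in param_op_hook.py, an assumption about
    # that module's API), any non-dunder torch op that touches a ColoParameter
    # is bracketed by ``pre_op``/``post_op``:
    #
    #   >>> x = torch.randn(2, 4)
    #   >>> with ParamOpHookManager.use_hooks(my_hook):    # my_hook: a ParamOpHook instance
    #   ...     y = torch.nn.functional.linear(x, p)       # pre_op([p]) ... post_op([p])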

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            with torch._C.DisableTorchFunction():
                data = self.data.clone()
            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
            memo[id(self)] = tensor
            return tensor
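
    # Copy sketch: ``copy.deepcopy`` routes through ``__deepcopy__`` above, so
    # the clone keeps ``requires_grad`` and the tensor spec but owns freshly
    # cloned storage:
    #
    #   >>> import copy
    #   >>> q = copy.deepcopy(p)
    #   >>> q.data_ptr() != p.data_ptr()
    #   True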

    def __reduce_ex__(self, proto):
        # Adapted from torch._utils._rebuild_parameter
        # def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
        #     colo_param = ColoParameter(data, requires_grad)
        #     colo_param._backward_hooks = backward_hooks
        #     return colo_param
        # return (
        #     _rebuild_colo_parameter,
        #     (self.data, self.requires_grad, OrderedDict())
        # )
        # TODO(jzy) we don't support object reflection for now.
        # distspec cannot be pickled or rebuilt because it is tightly coupled to the runtime attribute `process_group`.
        raise NotImplementedError
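
# Serialization note: because ``__reduce_ex__`` raises, pickling a
# ColoParameter fails by design; checkpointing has to go through plain tensors
# that are rebuilt with ``from_torch_tensor`` and the desired TensorSpec on
# load (stated here as a workaround sketch, not an API of this file):
#
#   >>> import pickle
#   >>> pickle.dumps(p)    # raises NotImplementedError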