ColossalAI/colossalai/tensor/colo_parameter.py

39 lines
1.3 KiB
Python
Raw Normal View History

2022-05-06 04:57:14 +00:00
from .colo_tensor import ColoTensor
from .const import TensorType
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
2022-05-06 04:57:14 +00:00
class ColoParameter(ColoTensor):
    r"""A kind of ColoTensor to be considered as a module parameter.

    Instances are created either directly (``ColoParameter(data)``), which
    re-wraps ``data`` through ``torch.Tensor._make_subclass`` with the
    requested ``requires_grad``, or from an existing tensor through
    :meth:`from_torch_tensor`. Every instance is tagged with
    ``TensorType.MODEL`` and holds its own ``TensorSpec`` object.
    """

    def __new__(cls,
                data: torch.Tensor = None,
                requires_grad: bool = True,
                spec: TensorSpec = None) -> 'ColoParameter':
        """Create the underlying tensor object.

        Args:
            data: tensor to wrap. ``None`` is replaced by ``torch.empty(0)``.
            requires_grad: passed to ``torch.Tensor._make_subclass`` and so
                sets ``requires_grad`` on the new parameter.
            spec: not used here; accepted because Python forwards the same
                arguments to both ``__new__`` and ``__init__``.

        Returns:
            A new ``ColoParameter`` (or subclass) instance.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self,
                 data: torch.Tensor = None,
                 requires_grad: bool = True,
                 spec: TensorSpec = None) -> None:
        """Attach ColossalAI metadata to the tensor built by ``__new__``.

        Args:
            data: not used here; the storage was bound in ``__new__``.
            requires_grad: not used here; it is applied in ``__new__``.
            spec: tensor spec to attach. A shallow copy is stored so the
                caller's object is not aliased. ``None`` (the default) means
                a fresh ``TensorSpec(distspec.replicate())``.
        """
        # Build the default spec per call rather than in the signature: a
        # signature default is evaluated once at class-definition time and
        # would be one object shared by every parameter, and ``copy`` is
        # shallow, so its nested state would stay shared after copying.
        if spec is None:
            spec = TensorSpec(distspec.replicate())
        else:
            spec = copy(spec)
        self._spec = spec
        self._type = TensorType.MODEL
        self._graph_node = None

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor,
                          requires_grad: bool = True,
                          spec: TensorSpec = None) -> 'ColoParameter':
        """Reinterpret ``tensor`` as a ``ColoParameter`` via ``as_subclass``.

        Args:
            tensor: source tensor; the result is
                ``tensor.as_subclass(ColoParameter)``.
            requires_grad: forwarded to ``__init__``, which does not use it.
                NOTE(review): this method therefore never sets
                ``requires_grad`` on the result; it is whatever
                ``as_subclass`` yields for ``tensor``. Callers should set it
                on ``tensor`` beforehand — confirm whether this argument is
                meant to be honoured here.
            spec: tensor spec, handled as in ``__init__`` (``None`` means a
                fresh replicated spec).

        Returns:
            The ``ColoParameter`` produced from ``tensor``.
        """
        tensor = tensor.as_subclass(ColoParameter)
        # ``as_subclass`` does not run ``__init__``, so call it explicitly to
        # attach the spec / type / graph-node metadata.
        tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
        return tensor

    def __repr__(self):
        """Return torch's standard tensor repr prefixed with the class tag."""
        return f'ColoParameter: {torch.Tensor.__repr__(self)}'