ColossalAI/colossalai/tensor/colo_parameter.py

from .colo_tensor import ColoTensor
from .const import TensorType
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy


class ColoParameter(ColoTensor):
    r"""A kind of ColoTensor to be considered as a module parameter.
    """

    def __new__(cls,
                data: torch.Tensor,
                requires_grad: bool = True,
                spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
        if data is None:
            data = torch.empty(0)
        # Create the tensor subclass directly from the payload, like torch.nn.Parameter does.
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self,
                 data: torch.Tensor,
                 requires_grad: bool = True,
                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
        # Keep a private copy of the distribution spec and mark this tensor as a model parameter.
        self._spec = copy(spec)
        self._type = TensorType.MODEL
        self._graph_node = None

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor,
                          requires_grad: bool = True,
                          spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
        # Convert an existing torch.Tensor in place: reinterpret it as a ColoParameter
        # (sharing storage), then initialize the ColoParameter-specific attributes.
        tensor = tensor.as_subclass(ColoParameter)
        tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
        return tensor
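
A minimal usage sketch (not part of the file above), assuming ColossalAI is importable and using only the constructs the file itself references (TensorSpec, distspec, and the module path shown in the header); the replicated spec mirrors the defaults in the signatures above.

import torch

from colossalai.tensor import TensorSpec, distspec
from colossalai.tensor.colo_parameter import ColoParameter

# Wrap an existing dense tensor as a replicated model parameter in place.
weight = torch.randn(4, 4)
param = ColoParameter.from_torch_tensor(weight,
                                        requires_grad=True,
                                        spec=TensorSpec(distspec.replicate()))

assert isinstance(param, ColoParameter)   # same storage, now viewed as a ColoParameter
print(param.shape)                        # torch.Size([4, 4])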