ColossalAI/colossalai/tensor/colo_parameter.py

29 lines
956 B
Python
Raw Normal View History

2022-05-06 04:57:14 +00:00
from .colo_tensor import ColoTensor
from .const import TensorType
import torch
class ColoParameter(ColoTensor):
r"""A kind of ColoTensor to be considered as a module parameter.
"""
def __init__(self, *args, **kargs):
super().__init__(*args, **kargs)
self._type = TensorType.MODEL
def __new__(cls, *args, **kwargs):
t = super(ColoParameter, cls).__new__(cls)
t._type = TensorType.MODEL
return t
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
colo_p = ColoParameter(*tensor.size(),
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
pin_memory=tensor.is_pinned(),
device=tensor.device,
torch_tensor=tensor if save_payload else torch.empty(0))
return colo_p