mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
29 lines
956 B
29 lines
956 B
3 years ago
|
from .colo_tensor import ColoTensor
|
||
|
from .const import TensorType
|
||
|
import torch
|
||
|
|
||
|
|
||
|
class ColoParameter(ColoTensor):
|
||
|
r"""A kind of ColoTensor to be considered as a module parameter.
|
||
|
|
||
|
"""
|
||
|
|
||
|
def __init__(self, *args, **kargs):
|
||
|
super().__init__(*args, **kargs)
|
||
|
self._type = TensorType.MODEL
|
||
|
|
||
|
def __new__(cls, *args, **kwargs):
|
||
|
t = super(ColoParameter, cls).__new__(cls)
|
||
|
t._type = TensorType.MODEL
|
||
|
return t
|
||
|
|
||
|
@staticmethod
|
||
|
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
|
||
|
colo_p = ColoParameter(*tensor.size(),
|
||
|
dtype=tensor.dtype,
|
||
|
requires_grad=tensor.requires_grad,
|
||
|
pin_memory=tensor.is_pinned(),
|
||
|
device=tensor.device,
|
||
|
torch_tensor=tensor if save_payload else torch.empty(0))
|
||
|
return colo_p
|