|
|
|
@ -1,6 +1,8 @@
|
|
|
|
|
from numpy import product |
|
|
|
|
import torch |
|
|
|
|
from .op_wrapper import _COLOSSAL_OPS |
|
|
|
|
from typing import Tuple |
|
|
|
|
import numpy |
|
|
|
|
from .op_wrapper import _COLOSSAL_OPS |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ColoTensor(object): |
|
|
|
@ -31,7 +33,7 @@ class ColoTensor(object):
|
|
|
|
|
self._torch_tensor = torch_tensor |
|
|
|
|
|
|
|
|
|
def numel(self): |
|
|
|
|
return sum(self._size) |
|
|
|
|
return product(self._size) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor': |
|
|
|
@ -44,9 +46,17 @@ class ColoTensor(object):
|
|
|
|
|
return colo_t |
|
|
|
|
|
|
|
|
|
def del_torch_tensor(self, save_shape=False) -> None: |
|
|
|
|
if save_shape: |
|
|
|
|
""" |
|
|
|
|
delete the payload of the torch tensor. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
save_shape (bool, optional): if saving the shape of the torch_tensor. |
|
|
|
|
If saving the shape, the size of self._torch_tensor is inconsist with the self._size. |
|
|
|
|
Defaults to False. |
|
|
|
|
""" |
|
|
|
|
if not save_shape: |
|
|
|
|
self._size = (0,) |
|
|
|
|
self._torch_tensor = torch.empty((0,)) |
|
|
|
|
self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype) |
|
|
|
|
|
|
|
|
|
def torch_tensor(self) -> torch.Tensor: |
|
|
|
|
if self._torch_tensor.numel() == 0: |
|
|
|
|