diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 20ceef71b..8d7e96120 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,6 +1,8 @@
+from numpy import product
 import torch
-from .op_wrapper import _COLOSSAL_OPS
 from typing import Tuple
+import numpy
+from .op_wrapper import _COLOSSAL_OPS
 
 
 class ColoTensor(object):
@@ -31,7 +33,7 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor
 
     def numel(self):
-        return sum(self._size)
+        return product(self._size)
 
     @staticmethod
     def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
@@ -44,9 +46,17 @@ class ColoTensor(object):
         return colo_t
 
     def del_torch_tensor(self, save_shape=False) -> None:
-        if save_shape:
+        """
+        Delete the payload of the torch tensor.
+
+        Args:
+            save_shape (bool, optional): whether to save the shape of the torch_tensor.
+                If the shape is saved, the size of self._torch_tensor is inconsistent with self._size.
+                Defaults to False.
+        """
+        if not save_shape:
             self._size = (0,)
-        self._torch_tensor = torch.empty((0,))
+        self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)
 
     def torch_tensor(self) -> torch.Tensor:
         if self._torch_tensor.numel() == 0:
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 71ce01dd6..fd9febf01 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -3,6 +3,7 @@ import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
 
+
 def test_linear():
     in_dim = 4
     out_dim = 5
@@ -44,6 +45,7 @@ def test_linear():
     # torch.nn.init.uniform_(t)
     # print(t)
 
+
 def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
@@ -59,10 +61,12 @@ def test_no_wrap_op():
     assert torch.sum(t) == torch.sum(t_ref)
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
+
 def test_lazy_init_tensor():
     lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
     assert lazy_t._torch_tensor.numel() == 0
-    assert lazy_t.torch_tensor().numel() == 6
+    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
+
 
 def check_all():
     test_linear()
@@ -70,5 +74,6 @@ def check_all():
     test_no_wrap_op()
     test_lazy_init_tensor()
 
+
 if __name__ == '__main__':
-    check_all()
+    test_lazy_init_tensor()
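
For reviewers, a minimal usage sketch (not part of the patch) of the behaviour this change enables; it mirrors the updated test_lazy_init_tensor and assumes ColoTensor is importable from colossalai.tensor:

    import torch
    from colossalai.tensor import ColoTensor

    # Lazy construction: no payload is allocated, but numel() now reports
    # the logical element count as the product of the stored shape
    # (previously it summed the dimensions).
    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
    assert lazy_t._torch_tensor.numel() == 0
    assert lazy_t.numel() == 6

    # Dropping the payload while keeping the shape: note the flipped
    # condition in del_torch_tensor -- self._size is reset only when
    # save_shape=False.
    t = ColoTensor.init_from_torch_tensor(torch.randn(2, 3))
    t.del_torch_tensor(save_shape=True)
    assert t.numel() == 6                  # logical shape is preserved
    assert t._torch_tensor.numel() == 0    # payload has been released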