
[hotfix] the bug of numel() in ColoTensor (#845)

Jiarui Fang committed 3 years ago (via GitHub)
pull/843/head^2 · commit ea0a2ed25f
Changed files:
  1. colossalai/tensor/colo_tensor.py (18 changed lines)
  2. tests/test_tensor/test_op.py (9 changed lines)

colossalai/tensor/colo_tensor.py (18 changed lines)

@@ -1,6 +1,8 @@
+from numpy import product
 import torch
-from .op_wrapper import _COLOSSAL_OPS
 from typing import Tuple
+import numpy
+from .op_wrapper import _COLOSSAL_OPS
 
 
 class ColoTensor(object):
@@ -31,7 +33,7 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor
 
     def numel(self):
-        return sum(self._size)
+        return product(self._size)
 
     @staticmethod
     def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
@@ -44,9 +46,17 @@
         return colo_t
 
     def del_torch_tensor(self, save_shape=False) -> None:
-        if save_shape:
+        """
+        Delete the payload of the torch tensor.
+
+        Args:
+            save_shape (bool, optional): whether to save the shape of the torch_tensor.
+                If the shape is saved, the size of self._torch_tensor is inconsistent with self._size.
+                Defaults to False.
+        """
+        if not save_shape:
             self._size = (0,)
-        self._torch_tensor = torch.empty((0,))
+        self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)
 
     def torch_tensor(self) -> torch.Tensor:
         if self._torch_tensor.numel() == 0:
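The core of the hotfix is the numel() change: a lazily initialized ColoTensor stores only its shape tuple, and the element count is the product of the dimensions, not their sum. A minimal sketch of the arithmetic, using math.prod here in place of the numpy product imported by the commit:

import math
import torch

size = (2, 3)                     # the shape a lazy ColoTensor keeps in self._size
print(sum(size))                  # 5 -> what the old numel() returned (wrong)
print(math.prod(size))            # 6 -> the fixed numel(), i.e. the true element count
print(torch.empty(size).numel())  # 6 -> what torch reports once the payload exists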

tests/test_tensor/test_op.py (9 changed lines)

@@ -3,6 +3,7 @@ import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
 
+
 def test_linear():
     in_dim = 4
     out_dim = 5
@@ -44,6 +45,7 @@ def test_linear():
     # torch.nn.init.uniform_(t)
     # print(t)
 
+
 def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
@@ -59,10 +61,12 @@ def test_no_wrap_op():
     assert torch.sum(t) == torch.sum(t_ref)
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
+
 def test_lazy_init_tensor():
     lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
     assert lazy_t._torch_tensor.numel() == 0
-    assert lazy_t.torch_tensor().numel() == 6
+    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
 
+
 def check_all():
     test_linear()
@@ -70,5 +74,6 @@ def check_all():
     test_no_wrap_op()
+    test_lazy_init_tensor()
 
 
 if __name__ == '__main__':
-    check_all()
+    test_lazy_init_tensor()
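For context, the assertion added to test_lazy_init_tensor relies on the lazy-init behaviour sketched below. This is a minimal standalone stand-in, not the real ColoTensor class: the payload stays an empty tensor until torch_tensor() materializes it, so numel() must be computed from the stored size.

import math
import torch


class LazyTensorSketch:
    """Illustrative stand-in for ColoTensor's lazy initialization (names are hypothetical)."""

    def __init__(self, *size, dtype=torch.float32):
        self._size = size
        self._dtype = dtype
        self._torch_tensor = torch.empty((0,), dtype=dtype)  # payload not allocated yet

    def numel(self):
        # product of the dims, as in the fix above (the commit uses numpy's product)
        return math.prod(self._size)

    def torch_tensor(self):
        if self._torch_tensor.numel() == 0:
            # materialize the payload lazily from the saved size
            self._torch_tensor = torch.empty(self._size, dtype=self._dtype)
        return self._torch_tensor


lazy_t = LazyTensorSketch(2, 3)
assert lazy_t._torch_tensor.numel() == 0                     # nothing allocated yet
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()  # mirrors the updated assertion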
