mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] the bug of numel() in ColoTensor (#845)
parent
c1e8d2001e
commit
ea0a2ed25f
|
@ -1,6 +1,8 @@
|
|||
from numpy import product
|
||||
import torch
|
||||
from .op_wrapper import _COLOSSAL_OPS
|
||||
from typing import Tuple
|
||||
import numpy
|
||||
from .op_wrapper import _COLOSSAL_OPS
|
||||
|
||||
|
||||
class ColoTensor(object):
|
||||
|
@ -31,7 +33,7 @@ class ColoTensor(object):
|
|||
self._torch_tensor = torch_tensor
|
||||
|
||||
def numel(self):
|
||||
return sum(self._size)
|
||||
return product(self._size)
|
||||
|
||||
@staticmethod
|
||||
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
|
||||
|
@ -44,9 +46,17 @@ class ColoTensor(object):
|
|||
return colo_t
|
||||
|
||||
def del_torch_tensor(self, save_shape=False) -> None:
|
||||
if save_shape:
|
||||
"""
|
||||
delete the payload of the torch tensor.
|
||||
|
||||
Args:
|
||||
save_shape (bool, optional): if saving the shape of the torch_tensor.
|
||||
If saving the shape, the size of self._torch_tensor is inconsist with the self._size.
|
||||
Defaults to False.
|
||||
"""
|
||||
if not save_shape:
|
||||
self._size = (0,)
|
||||
self._torch_tensor = torch.empty((0,))
|
||||
self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)
|
||||
|
||||
def torch_tensor(self) -> torch.Tensor:
|
||||
if self._torch_tensor.numel() == 0:
|
||||
|
|
|
@ -3,6 +3,7 @@ import torch
|
|||
from colossalai.tensor import ColoTensor
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def test_linear():
|
||||
in_dim = 4
|
||||
out_dim = 5
|
||||
|
@ -44,6 +45,7 @@ def test_linear():
|
|||
# torch.nn.init.uniform_(t)
|
||||
# print(t)
|
||||
|
||||
|
||||
def test_element_wise():
|
||||
t_ref = torch.randn(3, 5)
|
||||
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
|
||||
|
@ -59,10 +61,12 @@ def test_no_wrap_op():
|
|||
assert torch.sum(t) == torch.sum(t_ref)
|
||||
assert torch.sum(input=t) == torch.sum(input=t_ref)
|
||||
|
||||
|
||||
def test_lazy_init_tensor():
|
||||
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
|
||||
assert lazy_t._torch_tensor.numel() == 0
|
||||
assert lazy_t.torch_tensor().numel() == 6
|
||||
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
|
||||
|
||||
|
||||
def check_all():
|
||||
test_linear()
|
||||
|
@ -70,5 +74,6 @@ def check_all():
|
|||
test_no_wrap_op()
|
||||
test_lazy_init_tensor()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_all()
|
||||
test_lazy_init_tensor()
|
||||
|
|
Loading…
Reference in New Issue