mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] the bug of numel() in ColoTensor (#845)
parent: c1e8d2001e
commit: ea0a2ed25f

The fix: ColoTensor.numel() computed the element count as the sum of the dimensions in self._size rather than their product, so e.g. a (2, 3) tensor reported 5 elements instead of 6. The commit also documents del_torch_tensor(), preserves device and dtype when the payload is dropped, and tightens the lazy-init test.
```diff
@@ -1,6 +1,8 @@
+from numpy import product
 import torch
-from .op_wrapper import _COLOSSAL_OPS
 from typing import Tuple
+import numpy
+from .op_wrapper import _COLOSSAL_OPS
 
 
 class ColoTensor(object):
```
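A side note on the newly added import (mine, not part of the commit): `from numpy import product` relies on an alias that was deprecated in later NumPy releases and removed in NumPy 2.0, so a sketch of the portable spellings:

```python
# Portability note (not from the commit): numpy's `product` alias is gone
# in NumPy >= 2.0. Equivalent, still-supported ways to get the same function:
from numpy import prod              # canonical NumPy name
from math import prod as math_prod  # stdlib alternative, Python 3.8+
```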
```diff
@@ -31,7 +33,7 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor
 
     def numel(self):
-        return sum(self._size)
+        return product(self._size)
 
     @staticmethod
     def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
```
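This hunk is the bug named in the commit title. A minimal standalone check of the behaviour, not taken from the repository: the element count of a tensor is the product of its dimensions, not their sum.

```python
from math import prod

import torch

size = (2, 3)
print(sum(size))                  # 5 -> what the old `sum(self._size)` returned
print(prod(size))                 # 6 -> the fixed result
print(torch.empty(size).numel())  # 6 -> agrees with PyTorch's own numel()
```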
```diff
@@ -44,9 +46,17 @@ class ColoTensor(object):
         return colo_t
 
     def del_torch_tensor(self, save_shape=False) -> None:
-        if save_shape:
+        """
+        Delete the payload of the torch tensor.
+
+        Args:
+            save_shape (bool, optional): whether to keep the shape of the torch_tensor.
+                If the shape is kept, the size of self._torch_tensor is inconsistent with self._size.
+                Defaults to False.
+        """
+        if not save_shape:
             self._size = (0,)
-        self._torch_tensor = torch.empty((0,))
+        self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)
 
     def torch_tensor(self) -> torch.Tensor:
         if self._torch_tensor.numel() == 0:
```
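Together with `torch_tensor()` shown in the trailing context, `del_torch_tensor()` implements a lazy-payload pattern: deleting the payload swaps in a zero-element tensor, and access re-materializes it when `numel() == 0`. A condensed, self-contained sketch of that pattern under hypothetical names (`LazyPayload` is mine; the real class carries much more state):

```python
import torch
from math import prod

class LazyPayload:
    """Toy stand-in for ColoTensor's payload handling (hypothetical)."""

    def __init__(self, *size, dtype=torch.float32, device='cpu'):
        self._size = size
        self._dtype = dtype
        self._device = device
        self._torch_tensor = torch.empty((0,), dtype=dtype, device=device)

    def numel(self):
        return prod(self._size)  # product of dims, i.e. the fixed behaviour

    def del_torch_tensor(self, save_shape=False):
        if not save_shape:
            self._size = (0,)
        # keep dtype/device so re-materialization matches the original tensor
        self._torch_tensor = torch.empty((0,), dtype=self._dtype, device=self._device)

    def torch_tensor(self):
        if self._torch_tensor.numel() == 0:  # deleted or never materialized
            self._torch_tensor = torch.empty(self._size, dtype=self._dtype, device=self._device)
        return self._torch_tensor

t = LazyPayload(2, 3)
assert t._torch_tensor.numel() == 0  # no payload allocated yet
assert t.numel() == 6                # metadata alone answers the count
assert t.torch_tensor().numel() == 6 # first access allocates the payload
```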
The remaining hunks land in the accompanying test module (the mirror omits file paths):

```diff
@@ -3,6 +3,7 @@ import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
 
+
 def test_linear():
     in_dim = 4
     out_dim = 5
```
```diff
@@ -44,6 +45,7 @@ def test_linear():
     # torch.nn.init.uniform_(t)
     # print(t)
 
+
 def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
```
```diff
@@ -59,10 +61,12 @@ def test_no_wrap_op():
     assert torch.sum(t) == torch.sum(t_ref)
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
+
 def test_lazy_init_tensor():
     lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
     assert lazy_t._torch_tensor.numel() == 0
-    assert lazy_t.torch_tensor().numel() == 6
+    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
 
+
 def check_all():
     test_linear()
```
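The strengthened assertion leans on Python's comparison chaining: `a == b == c` evaluates as `(a == b) and (b == c)`, so one line checks both the metadata count and the materialized tensor's count against 6. A quick illustration (mine, not from the test):

```python
# `a == b == c` chains pairwise; it does NOT compare the bool `(a == b)` to c.
a, b, c = 6, 6, 6
assert (a == b == c) == ((a == b) and (b == c))
```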
```diff
@@ -70,5 +74,6 @@ def check_all():
     test_no_wrap_op()
     test_lazy_init_tensor()
 
+
 if __name__ == '__main__':
-    check_all()
+    test_lazy_init_tensor()
```