
[Tensor] Add some attributes to ColoTensor (#877)

* [Tensor] add some functions to ColoTensor

* torch.allclose

* rm torch.add
Jiarui Fang, committed 3 years ago (via GitHub)
Commit 909211453b
Changed files:
  1. colossalai/tensor/_ops/element_wise.py (13 changes)
  2. colossalai/tensor/colo_tensor.py (11 changes)
  3. tests/test_tensor/test_op.py (10 changes)
  4. tests/test_tensor/test_tensor.py (15 changes)

colossalai/tensor/_ops/element_wise.py (13 changes)

@@ -3,6 +3,19 @@ from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor

+@colo_op_impl(torch.allclose)
+def colo_allclose(types, args=(), kwargs=None, pg=None):
+    # Unwrap any ColoTensor argument to its underlying torch.Tensor before
+    # delegating to torch.allclose.
+    a = args[0]
+    b = args[1]
+    if isinstance(a, ColoTensor):
+        a = a.torch_tensor()
+    if isinstance(b, ColoTensor):
+        b = b.torch_tensor()
+    return torch.allclose(a, b, **kwargs)
+
+
 @colo_op_impl(torch.mean)
 def colo_mean(types, args=(), kwargs=None, pg=None):
     input_t = args[0]
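For orientation, a minimal usage sketch of the handler registered above. It assumes the usual colo_op_impl dispatch path (torch.allclose calls whose arguments include a ColoTensor are routed to this handler); the tensor names and shapes are illustrative, not part of the commit.

import torch
from colossalai.tensor import ColoTensor

# Two ColoTensors wrapping identical data; names are made up for this sketch.
a = ColoTensor.init_from_torch_tensor(torch.randn(2, 3))
b = ColoTensor.init_from_torch_tensor(a.torch_tensor().clone())

# Dispatched to the colo_op_impl(torch.allclose) handler, which unwraps both
# arguments and returns a plain Python bool.
assert torch.allclose(a, b)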

colossalai/tensor/colo_tensor.py (11 changes)

@@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction

 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.

@@ -37,6 +38,9 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor
         self._shard_spec = shard_spec

+    def __getitem__(self, key):
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
+
     @property
     def shard_spec(self) -> TensorSpec:
         return self._shard_spec

@@ -148,7 +152,10 @@ class ColoTensor(object):
             kwargs = {}

         kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
-        return func(*args, **kwargs)
+        return ColoTensor.init_from_torch_tensor(func(*args, **kwargs))

-    def backward(self, gradient: Optional[torch.Tensor] = None , retain_graph: bool = False):
+    def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
+
+    def __add__(self, o) -> "ColoTensor":
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
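A short, hedged sketch of what the new dunder methods enable; the variable names and shapes are illustrative only, not taken from the commit.

import torch
from colossalai.tensor import ColoTensor

x = ColoTensor.init_from_torch_tensor(torch.randn(2, 3))
y = ColoTensor.init_from_torch_tensor(torch.ones(2, 3))

row = x[0]   # __getitem__ slices the wrapped tensor and re-wraps it as a ColoTensor
z = x + y    # __add__ adds the underlying torch tensors and wraps the result

assert isinstance(row, ColoTensor) and isinstance(z, ColoTensor)
assert torch.allclose(z.torch_tensor(), x.torch_tensor() + 1.0)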

tests/test_tensor/test_op.py (10 changes)

@@ -87,19 +87,11 @@ def test_no_wrap_op():
     assert torch.sum(input=t) == torch.sum(input=t_ref)

-
-def test_lazy_init_tensor():
-    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
-    assert lazy_t._torch_tensor.numel() == 0
-    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
-

 def check_all():
     test_linear()
     test_element_wise()
     test_no_wrap_op()
-    test_lazy_init_tensor()

 if __name__ == '__main__':
     # test_lazy_init_ptensor()
     test_layernorm()
     check_all()

tests/test_tensor/test_tensor.py (15 additions, new file)

@@ -0,0 +1,15 @@
+import torch
+from colossalai.tensor import ColoTensor
+from numpy import allclose
+
+
+def test_tensor_indexing():
+    torch_t = torch.randn(2, 3)
+    colo_t = ColoTensor.init_from_torch_tensor(torch_t)
+    assert allclose(torch_t[:, 1], colo_t[:, 1].torch_tensor())
+
+
+def test_lazy_init_tensor():
+    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor.numel() == 0
+    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
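The moved lazy-init test and the new indexing test can be run on their own; a hedged example below, assuming pytest is the runner used for the repo's tests/ tree (the flag and path are illustrative).

# Roughly equivalent to running `pytest tests/test_tensor/test_tensor.py -q` from the repo root.
import pytest

if __name__ == '__main__':
    pytest.main(['-q', 'tests/test_tensor/test_tensor.py'])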