diff --git a/colossalai/tensor/_ops/element_wise.py b/colossalai/tensor/_ops/element_wise.py
index 0bc932ec3..98f449188 100644
--- a/colossalai/tensor/_ops/element_wise.py
+++ b/colossalai/tensor/_ops/element_wise.py
@@ -3,6 +3,19 @@ from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor
 
 
+@colo_op_impl(torch.allclose)
+def colo_allclose(types, args=(), kwargs=None, pg=None):
+    a = args[0]
+    b = args[1]
+
+    if isinstance(a, ColoTensor):
+        a = a.torch_tensor()
+    if isinstance(b, ColoTensor):
+        b = b.torch_tensor()
+
+    return torch.allclose(a, b, **kwargs)
+
+
 @colo_op_impl(torch.mean)
 def colo_mean(types, args=(), kwargs=None, pg=None):
     input_t = args[0]
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 1824f0b49..477d02114 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
 
+
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
@@ -37,6 +38,9 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor
         self._shard_spec = shard_spec
 
+    def __getitem__(self, key):
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
+
     @property
     def shard_spec(self) -> TensorSpec:
         return self._shard_spec
@@ -148,7 +152,10 @@ class ColoTensor(object):
                 kwargs = {}
 
             kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
-            return func(*args, **kwargs)
+            return ColoTensor.init_from_torch_tensor(func(*args, **kwargs))
 
-    def backward(self, gradient: Optional[torch.Tensor] = None , retain_graph: bool = False):
+    def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
+
+    def __add__(self, o) -> "ColoTensor":
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 40ecc10fe..4babb73cd 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -87,19 +87,11 @@ def test_no_wrap_op():
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
 
-def test_lazy_init_tensor():
-    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
-    assert lazy_t._torch_tensor.numel() == 0
-    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
-
-
 def check_all():
     test_linear()
     test_element_wise()
     test_no_wrap_op()
-    test_lazy_init_tensor()
 
 
 if __name__ == '__main__':
-    # test_lazy_init_ptensor()
-    test_layernorm()
+    check_all()
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
new file mode 100644
index 000000000..c7216a9f8
--- /dev/null
+++ b/tests/test_tensor/test_tensor.py
@@ -0,0 +1,15 @@
+import torch
+from colossalai.tensor import ColoTensor
+from numpy import allclose
+
+
+def test_tensor_indexing():
+    torch_t = torch.randn(2, 3)
+    colo_t = ColoTensor.init_from_torch_tensor(torch_t)
+    assert allclose(torch_t[:, 1], colo_t[:, 1].torch_tensor())
+
+
+def test_lazy_init_tensor():
+    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor.numel() == 0
+    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
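Reviewer note, not part of the patch: a minimal usage sketch of the operators this diff adds (__getitem__, __add__, and the torch.allclose wrapper), assuming the ColoTensor API exactly as it appears above; the variable names are illustrative only.

import torch
from colossalai.tensor import ColoTensor

torch_a = torch.randn(2, 3)
torch_b = torch.randn(2, 3)
colo_a = ColoTensor.init_from_torch_tensor(torch_a)
colo_b = ColoTensor.init_from_torch_tensor(torch_b)

# __getitem__ slices the wrapped torch.Tensor and re-wraps the result
col = colo_a[:, 1]
assert isinstance(col, ColoTensor)

# __add__ adds the wrapped tensors and re-wraps the sum
colo_sum = colo_a + colo_b

# compare against the plain torch result; the new colo_op_impl wrapper is
# intended to let torch.allclose accept ColoTensor arguments directly as well
assert torch.allclose(colo_sum.torch_tensor(), torch_a + torch_b)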