mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] Add some attributes to ColoTensor (#877)
* [Tensor] add some function to ColoTensor
* torch.allclose
* rm torch.add
parent 425b4a96b8
commit 909211453b
@@ -3,6 +3,19 @@ from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor
 
 
+@colo_op_impl(torch.allclose)
+def colo_mean(types, args=(), kwargs=None, pg=None):
+    a = args[0]
+    b = args[1]
+
+    if isinstance(a, ColoTensor):
+        a = a.torch_tensor()
+    elif isinstance(b, ColoTensor):
+        b = b.torch_tensor()
+
+    return torch.allclose(a, b, **kwargs)
+
+
 @colo_op_impl(torch.mean)
 def colo_mean(types, args=(), kwargs=None, pg=None):
     input_t = args[0]
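As a rough usage sketch (not part of the commit): assuming ColoTensor takes part in torch function dispatch through the `__torch_function__` path shown in the colo_tensor.py hunk further down, a call like the one below would be routed to the handler registered above, which unwraps the ColoTensor operand and falls back to the native kernel.

import torch
from colossalai.tensor import ColoTensor

# Illustrative values; only one operand is a ColoTensor, matching the
# if/elif unwrapping in the registered handler.
ref = torch.randn(2, 3)
colo = ColoTensor.init_from_torch_tensor(ref.clone())

assert torch.allclose(colo, ref)                     # intercepted, then compared as plain tensors
assert torch.allclose(colo, ref + 1e-9, atol=1e-6)   # keyword args are forwarded via **kwargs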
@@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
 
+
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
@@ -37,6 +38,9 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor
         self._shard_spec = shard_spec
 
+    def __getitem__(self, key):
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
+
     @property
     def shard_spec(self) -> TensorSpec:
         return self._shard_spec
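A small sketch of what the new `__getitem__` gives callers (illustrative shapes, assuming the usual init_from_torch_tensor/torch_tensor() round trip): slicing stays inside the ColoTensor API instead of handing back a bare torch.Tensor.

import torch
from colossalai.tensor import ColoTensor

data = torch.arange(6, dtype=torch.float32).reshape(2, 3)
t = ColoTensor.init_from_torch_tensor(data)

col = t[:, 1]                       # handled by the new __getitem__
assert isinstance(col, ColoTensor)  # the slice comes back wrapped
assert torch.equal(col.torch_tensor(), data[:, 1])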
@@ -148,7 +152,10 @@ class ColoTensor(object):
             kwargs = {}
 
         kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
-        return func(*args, **kwargs)
+        return ColoTensor.init_from_torch_tensor(func(*args, **kwargs))
 
-    def backward(self, gradient: Optional[torch.Tensor] = None , retain_graph: bool = False):
+    def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
+
+    def __add__(self, o) -> "ColoTensor":
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
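Two user-visible effects of this hunk: results produced through the generic dispatch fallback are now re-wrapped as ColoTensor, and two ColoTensors can be added with the plain + operator. The sketch below only exercises __add__, since which torch functions actually reach the generic fallback depends on the op registry and is not visible in this diff.

import torch
from colossalai.tensor import ColoTensor

x = ColoTensor.init_from_torch_tensor(torch.ones(2, 2))
y = ColoTensor.init_from_torch_tensor(torch.ones(2, 2))

s = x + y                         # __add__ unwraps both operands and re-wraps the sum
assert isinstance(s, ColoTensor)
assert torch.equal(s.torch_tensor(), torch.full((2, 2), 2.0))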
@@ -87,19 +87,11 @@ def test_no_wrap_op():
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
-def test_lazy_init_tensor():
-    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
-    assert lazy_t._torch_tensor.numel() == 0
-    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
-
-
 def check_all():
     test_linear()
     test_element_wise()
     test_no_wrap_op()
-    test_lazy_init_tensor()
 
 
 if __name__ == '__main__':
-    # test_lazy_init_ptensor()
     check_all()
     test_layernorm()
@@ -0,0 +1,15 @@
+import torch
+from colossalai.tensor import ColoTensor
+from numpy import allclose
+
+
+def test_tensor_indexing():
+    torch_t = torch.randn(2, 3)
+    colo_t = ColoTensor.init_from_torch_tensor(torch_t)
+    assert allclose(torch_t[:, 1], colo_t[:, 1].torch_tensor())
+
+
+def test_lazy_init_tensor():
+    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor.numel() == 0
+    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
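The two new tests pin down the indexing and lazy-init behaviour added in this commit. As a final hedged sketch (same dispatch assumptions as above), the additions also compose: an indexed ColoTensor can be compared against the original payload without any manual unwrapping, because torch.allclose is intercepted as well.

import torch
from colossalai.tensor import ColoTensor

t = torch.randn(4, 4)
colo = ColoTensor.init_from_torch_tensor(t)

# colo[1:, :2] is a ColoTensor (new __getitem__); torch.allclose on it goes
# through the handler registered in the first hunk, which unwraps it.
assert torch.allclose(colo[1:, :2], t[1:, :2])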