mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] Add some attributes to ColoTensor (#877)
* [Tensor] add some function to ColoTensor * torch.allclose * rm torch.addpull/878/head
parent
425b4a96b8
commit
909211453b
|
@ -3,6 +3,19 @@ from colossalai.tensor.op_wrapper import colo_op_impl
|
|||
from colossalai.tensor import ColoTensor
|
||||
|
||||
|
||||
@colo_op_impl(torch.allclose)
|
||||
def colo_mean(types, args=(), kwargs=None, pg=None):
|
||||
a = args[0]
|
||||
b = args[1]
|
||||
|
||||
if isinstance(a, ColoTensor):
|
||||
a = a.torch_tensor()
|
||||
elif isinstance(b, ColoTensor):
|
||||
b = b.torch_tensor()
|
||||
|
||||
return torch.allclose(a, b, **kwargs)
|
||||
|
||||
|
||||
@colo_op_impl(torch.mean)
|
||||
def colo_mean(types, args=(), kwargs=None, pg=None):
|
||||
input_t = args[0]
|
||||
|
|
|
@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc
|
|||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
|
||||
|
||||
|
||||
class ColoTensor(object):
|
||||
""" Data Structure for Tensor in Colossal-AI
|
||||
1. It contains a torch.Tensor as an attribute.
|
||||
|
@ -37,6 +38,9 @@ class ColoTensor(object):
|
|||
self._torch_tensor = torch_tensor
|
||||
self._shard_spec = shard_spec
|
||||
|
||||
def __getitem__(self, key):
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
|
||||
|
||||
@property
|
||||
def shard_spec(self) -> TensorSpec:
|
||||
return self._shard_spec
|
||||
|
@ -148,7 +152,10 @@ class ColoTensor(object):
|
|||
kwargs = {}
|
||||
|
||||
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
|
||||
return func(*args, **kwargs)
|
||||
return ColoTensor.init_from_torch_tensor(func(*args, **kwargs))
|
||||
|
||||
def backward(self, gradient: Optional[torch.Tensor] = None , retain_graph: bool = False):
|
||||
def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
|
||||
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
|
||||
|
||||
def __add__(self, o) -> "ColoTensor":
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
|
||||
|
|
|
@ -87,19 +87,11 @@ def test_no_wrap_op():
|
|||
assert torch.sum(input=t) == torch.sum(input=t_ref)
|
||||
|
||||
|
||||
def test_lazy_init_tensor():
|
||||
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
|
||||
assert lazy_t._torch_tensor.numel() == 0
|
||||
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
|
||||
|
||||
|
||||
def check_all():
|
||||
test_linear()
|
||||
test_element_wise()
|
||||
test_no_wrap_op()
|
||||
test_lazy_init_tensor()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_lazy_init_ptensor()
|
||||
test_layernorm()
|
||||
check_all()
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
import torch
|
||||
from colossalai.tensor import ColoTensor
|
||||
from numpy import allclose
|
||||
|
||||
|
||||
def test_tensor_indexing():
|
||||
torch_t = torch.randn(2, 3)
|
||||
colo_t = ColoTensor.init_from_torch_tensor(torch_t)
|
||||
assert allclose(torch_t[:, 1], colo_t[:, 1].torch_tensor())
|
||||
|
||||
|
||||
def test_lazy_init_tensor():
|
||||
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
|
||||
assert lazy_t._torch_tensor.numel() == 0
|
||||
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
|
Loading…
Reference in New Issue