mirror of https://github.com/hpcaitech/ColossalAI
[tensor] wrap function in the torch_tensor to ColoTensor (#881)
parent 4df6471f5d
commit 9bc5a77c31
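This commit routes torch.Tensor methods that ColoTensor does not define explicitly through a generic __getattr__ forwarder, so results come back wrapped as ColoTensor rather than as raw torch.Tensor. The forward-then-rewrap pattern can be sketched in isolation; the class below (SimpleWrapper) is a hypothetical toy for illustration, not ColossalAI code:

    import torch

    class SimpleWrapper:
        """Toy stand-in for ColoTensor: wraps a torch.Tensor and forwards method calls."""

        def __init__(self, t: torch.Tensor):
            self._t = t

        def torch_tensor(self) -> torch.Tensor:
            return self._t

        def __getattr__(self, name):
            # Only invoked when normal attribute lookup fails, i.e. for names
            # defined on torch.Tensor but not on the wrapper itself.
            attr = getattr(self._t, name)
            if not callable(attr):
                return attr    # plain attributes (e.g. is_cuda) pass through

            def execute_func(*args, **kwargs):
                out = attr(*args, **kwargs)
                # Rewrap tensor results so chained calls stay wrapped.
                return SimpleWrapper(out) if isinstance(out, torch.Tensor) else out

            return execute_func

    w = SimpleWrapper(torch.randn(4, 5))
    assert isinstance(w.abs(), SimpleWrapper)    # tensor result comes back wrapped
    assert w.dim() == 2                          # non-tensor result passes through
    assert w.is_cuda is False                    # non-callable attribute passes through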
@@ -2,7 +2,7 @@ from colossalai.context import parallel_mode
 from .op_wrapper import _COLOSSAL_OPS
 import torch
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Callable
 from numpy import product
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
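Callable is imported so that the new __getattr__ in the next hunk can test whether a forwarded attribute is a method or a plain value. typing.Callable delegates isinstance checks to collections.abc.Callable, so the check behaves like the builtin callable(); a quick standalone illustration (plain Python, no ColossalAI needed):

    from typing import Callable

    def f():
        return None

    assert isinstance(f, Callable)                   # typing.Callable supports isinstance
    assert isinstance(f, Callable) == callable(f)    # equivalent to the builtin check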
@@ -152,26 +152,28 @@ class ColoTensor(object):
                 kwargs = {}
 
             kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
-            return ColoTensor.init_from_torch_tensor(func(*args, **kwargs))
+            return cls._filter_outputs_with_colo(func(*args, **kwargs))
 
     def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
 
     ## TODO(fjr): reduce the redundancy of the following code
     def __add__(self, o) -> "ColoTensor":
         return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
 
-    def __truediv__(self, o) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
+    def __getattr__(self, name):
+
+        def replace_tensor_with_colo(func):
+
+            def execute_func(*args, **kwargs):
+                return self._filter_outputs_with_colo(func(*args, **kwargs))
+
+            return execute_func
+
+        attr = getattr(self._torch_tensor, name)
+        if isinstance(attr, Callable):
+            return replace_tensor_with_colo(attr)
+        else:
+            return attr
 
     def view(self, *args: int) -> "ColoTensor":
         return ColoTensor.init_from_torch_tensor(self.torch_tensor().view(*args))
 
     def permute(self, *args) -> "ColoTensor":
         return ColoTensor.init_from_torch_tensor(self.torch_tensor().permute(*args))
 
     def transpose(self, *args) -> "ColoTensor":
         return ColoTensor.init_from_torch_tensor(self.torch_tensor().transpose(*args))
 
     def contiguous(self):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor().contiguous())
 
+    @classmethod
+    def _filter_outputs_with_colo(cls, outputs):
+        if outputs is None:    # returns None
+            return None
+        elif type(outputs) is not tuple:    # number of return values = 1
+            return ColoTensor.init_from_torch_tensor(outputs) if type(outputs) is torch.Tensor else outputs
+        else:    # number of return values > 1
+            return tuple([ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output for output in outputs])
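After this hunk, any attribute ColoTensor does not define explicitly is resolved by __getattr__: the attribute is fetched from the wrapped _torch_tensor, callables are rewrapped so their results pass through _filter_outputs_with_colo, and the three branches there handle None, a single value, and a tuple of values respectively. Assuming the ColoTensor API shown above (imports omitted), the intended call paths are:

    t = ColoTensor.init_from_torch_tensor(torch.randn(4, 5))

    t_abs = t.abs()        # single torch.Tensor result, rewrapped as ColoTensor
    ndim = t.dim()         # non-tensor result (an int), returned as-is
    a, b = t.split(2)      # tuple of torch.Tensor, rewrapped elementwise

One caveat worth noting: _filter_outputs_with_colo checks type(outputs) is not tuple, so named-tuple results such as the return type of torch.Tensor.sort (a tuple subclass) take the single-value branch and come back unwrapped.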
@@ -86,12 +86,32 @@ def test_no_wrap_op():
     assert torch.sum(t) == torch.sum(t_ref)
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
+
+def test_wrapped_tensor_func():
+    t_ref = torch.randn(4, 5)
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
+
+    # non-callable attribute
+    assert t.is_cuda == t_ref.is_cuda
+
+    # TODO: no tensor function that returns None has been found to exercise that branch.
+
+    # returns a single torch.Tensor
+    t_abs = t.abs()
+    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())
+
+    # returns a single non-torch.Tensor value
+    assert t.dim() == t_ref.dim()
+
+    # returns more than one torch.Tensor
+    t_split1, t_split2 = t.split(2)
+    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)
+
+
 def check_all():
     test_linear()
     test_element_wise()
     test_no_wrap_op()
+    test_wrapped_tensor_func()
 
 
 if __name__ == '__main__':
     check_all()
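The first two assertions in test_no_wrap_op go through a different mechanism than __getattr__: torch.sum(t) is intercepted by ColoTensor's __torch_function__ classmethod (the return statement changed in the second hunk), which unwraps ColoTensor arguments, runs the real op, and filters the outputs. For readers unfamiliar with the protocol, here is a minimal standalone sketch of __torch_function__ dispatch; the Wrapped class is a hypothetical toy, not ColossalAI code:

    import torch

    class Wrapped:

        def __init__(self, t: torch.Tensor):
            self._t = t

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # torch.* functions land here whenever any argument is a Wrapped:
            # unwrap the arguments, then run the real torch op.
            kwargs = kwargs or {}
            args = [a._t if isinstance(a, Wrapped) else a for a in args]
            kwargs = {k: v._t if isinstance(v, Wrapped) else v for k, v in kwargs.items()}
            return func(*args, **kwargs)

    w = Wrapped(torch.ones(3))
    assert torch.sum(w) == 3           # dispatch on a positional argument
    assert torch.sum(input=w) == 3     # dispatch on a keyword argument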