mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] make ColoTensor more robust for getattr (#886)
* [Tensor] make ColoTensor more robust for getattr
* polish
* polish

pull/890/head
parent 9bc5a77c31
commit 72cdc06875
@@ -12,7 +12,8 @@ def colo_mean(types, args=(), kwargs=None, pg=None):
         a = a.torch_tensor()
     elif isinstance(b, ColoTensor):
         b = b.torch_tensor()
+    if kwargs is None:
+        kwargs = {}
     return torch.allclose(a, b, **kwargs)
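The guard added above exists because `kwargs` defaults to `None` and is later unpacked with `**kwargs`; unpacking `None` raises a `TypeError`. A minimal sketch of the same pattern outside ColossalAI (the `check_close` helper name is made up for illustration):

import torch

def check_close(a, b, kwargs=None):
    # `**None` would raise TypeError, so normalize to an empty dict first.
    if kwargs is None:
        kwargs = {}
    return torch.allclose(a, b, **kwargs)

x = torch.ones(3)
assert check_close(x, x)                          # no extra options
assert check_close(x, x + 1e-9, {'atol': 1e-6})   # options forwarded to allclose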
@@ -152,18 +152,34 @@ class ColoTensor(object):
                 kwargs = {}

             kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
-            return cls._filter_outputs_with_colo(func(*args,**kwargs))
+            return cls._filter_outputs_with_colo(func(*args, **kwargs))

     def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)

+    def __add__(self, o) -> "ColoTensor":
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
+
+    def __truediv__(self, o) -> "ColoTensor":
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
+
     def __getattr__(self, name):

         def replace_tensor_with_colo(func):

             def execute_func(*args, **kwargs):
-                return self._filter_outputs_with_colo(func(*args,**kwargs))
+                # transform the ColoTensor args to torch Tensor.
+                args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
+                if kwargs is None:
+                    kwargs = {}
+                kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
+                return self._filter_outputs_with_colo(func(*args, **kwargs))

             return execute_func

+        assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor"
        attr = getattr(self._torch_tensor, name)

         if isinstance(attr, Callable):
             return replace_tensor_with_colo(attr)
         else:
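The `__getattr__` hook above is what the commit title refers to: an attribute that is not defined on ColoTensor itself is looked up on the wrapped `torch.Tensor` (with an assert giving a clear error when it does not exist there either), non-callable attributes are returned directly, and callables are wrapped so that ColoTensor arguments are unwrapped before the call and tensor results are re-wrapped afterwards. A stripped-down, self-contained sketch of that delegation pattern (the `TensorProxy` class is illustrative, not the ColossalAI API):

import torch

class TensorProxy:
    """Minimal stand-in for the ColoTensor-style attribute delegation."""

    def __init__(self, t: torch.Tensor):
        self._t = t

    def __getattr__(self, name):
        # Only reached when `name` is not found on the proxy itself.
        assert hasattr(self._t, name), f"torch.Tensor has no attribute {name}"
        attr = getattr(self._t, name)
        if not callable(attr):
            return attr    # plain attribute, e.g. .is_cuda or .shape

        def wrapped(*args, **kwargs):
            # Unwrap proxy arguments, call the real method, re-wrap tensor outputs.
            args = [a._t if isinstance(a, TensorProxy) else a for a in args]
            kwargs = {k: v._t if isinstance(v, TensorProxy) else v for k, v in kwargs.items()}
            out = attr(*args, **kwargs)
            return TensorProxy(out) if isinstance(out, torch.Tensor) else out

        return wrapped

p = TensorProxy(torch.randn(4, 5))
assert isinstance(p.abs(), TensorProxy)    # callable attr: tensor result re-wrapped
assert p.dim() == 2                        # non-tensor result passed through
assert p.is_cuda is False                  # non-callable attr forwarded directly

Compared with this sketch, the real ColoTensor routes results through `_filter_outputs_with_colo` instead of re-wrapping inline, which also covers tuple returns (see the next hunk).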
@@ -176,4 +192,7 @@ class ColoTensor(object):
         elif type(outputs) is not tuple:    # num of return val = 1
             return ColoTensor.init_from_torch_tensor(outputs) if type(outputs) is torch.Tensor else outputs
         else:    # num of return val > 1
-            return tuple([ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output for output in outputs])
+            return tuple([
+                ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
+                for output in outputs
+            ])
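The reformatting above makes the multi-output branch easier to read: when the delegated call returns a tuple (for example `Tensor.split`), every `torch.Tensor` element is converted individually and anything else is passed through untouched. A small stand-alone illustration of that filtering logic (the `wrap_outputs` helper and the tagging lambda are made up):

import torch

def wrap_outputs(outputs, wrap):
    # Mirrors the one-return / many-return branches shown in the diff.
    if not isinstance(outputs, tuple):    # num of return val = 1
        return wrap(outputs) if isinstance(outputs, torch.Tensor) else outputs
    return tuple(                         # num of return val > 1
        wrap(o) if isinstance(o, torch.Tensor) else o for o in outputs)

t = torch.arange(6.0).reshape(2, 3)
single = wrap_outputs(t.abs(), wrap=lambda x: ("wrapped", x))
many = wrap_outputs(t.split(1), wrap=lambda x: ("wrapped", x))
assert single[0] == "wrapped"
assert all(o[0] == "wrapped" for o in many)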
@@ -86,32 +86,12 @@ def test_no_wrap_op():
     assert torch.sum(t) == torch.sum(t_ref)
     assert torch.sum(input=t) == torch.sum(input=t_ref)


-def test_wrapped_tensor_func():
-    t_ref = torch.randn(4, 5)
-    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
-
-    # non-func attr
-    assert t.is_cuda == t_ref.is_cuda
-
-    # TODO I don't find out a tensor function which returns None.
-
-    # return 1 torch.Tensor
-    t_abs = t.abs()
-    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())
-
-    # return 1 non-torch.Tensor
-    assert t.dim() == t_ref.dim()
-
-    # return >1 torch.Tensor
-    t_split1, t_split2 = t.split(2)
-    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)
-
-
 def check_all():
     test_linear()
     test_element_wise()
     test_no_wrap_op()
-    test_wrapped_tensor_func()


 if __name__ == '__main__':
     check_all()
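`test_wrapped_tensor_func` leaves this file here and reappears in the tensor test module below; what remains of `test_no_wrap_op` checks both `torch.sum(t)` and `torch.sum(input=t)`, i.e. that a wrapped tensor hidden inside `kwargs` gets unwrapped exactly like a positional one. A hedged sketch of that normalization step in isolation (the `unwrap_call_args` helper and the list-based stand-in for ColoTensor are hypothetical):

import torch

def unwrap_call_args(args, kwargs, is_wrapped, unwrap):
    """Unwrap wrapped tensors appearing either positionally or by keyword."""
    args = [unwrap(a) if is_wrapped(a) else a for a in args]
    if kwargs is None:
        kwargs = {}
    kwargs = {k: unwrap(v) if is_wrapped(v) else v for k, v in kwargs.items()}
    return args, kwargs

# With a plain list as the "wrapper", torch.sum sees the same tensor either way.
wrapped = [torch.ones(2, 3)]                 # stand-in for a ColoTensor
is_wrapped = lambda v: isinstance(v, list)
unwrap = lambda v: v[0]

pos_args, pos_kwargs = unwrap_call_args([wrapped], None, is_wrapped, unwrap)
kw_args, kw_kwargs = unwrap_call_args([], {"input": wrapped}, is_wrapped, unwrap)
assert torch.sum(*pos_args, **pos_kwargs) == torch.sum(*kw_args, **kw_kwargs)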
@@ -13,3 +13,33 @@ def test_lazy_init_tensor():
     lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
     assert lazy_t._torch_tensor.numel() == 0
     assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
+
+
+def test_wrapped_tensor_func():
+    t_ref = torch.randn(4, 5)
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
+
+    # non-func attr
+    assert t.is_cuda == t_ref.is_cuda
+
+    # TODO I don't find out a tensor function which returns None.
+
+    # return 1 torch.Tensor
+    t_abs = t.abs()
+    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())
+
+    # return 1 non-torch.Tensor
+    assert t.dim() == t_ref.dim()
+
+    # return >1 torch.Tensor
+    t_split1, t_split2 = t.split(2)
+    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)
+
+
+def test_operand():
+    t_ref = torch.randn(4, 5)
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
+
+    t_ref_res = t_ref + t_ref
+    t_res = t + t
+    assert torch.allclose(t_ref_res, t_res)
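The new `test_operand` covers `__add__`; together with `__truediv__` from the ColoTensor hunk above, expressions like an element-wise mean of two wrapped tensors become possible. A hedged usage sketch (the import path is an assumption; note that as written in this commit `__add__` unwraps its operand with `o.torch_tensor()`, so it expects another ColoTensor, while `__truediv__` hands `o` straight to torch and therefore also accepts a plain scalar):

import torch
from colossalai.tensor import ColoTensor  # import path is an assumption

a = ColoTensor.init_from_torch_tensor(torch.full((2, 2), 2.0))
b = ColoTensor.init_from_torch_tensor(torch.full((2, 2), 4.0))

mean = (a + b) / 2    # __add__ on two ColoTensors, __truediv__ by a scalar
assert isinstance(mean, ColoTensor)
assert torch.equal(mean.torch_tensor(), torch.full((2, 2), 3.0))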