[tensor] colo tensor overrides mul (#927)

* colo tensor overrides mul

* polish code
pull/929/head
ver217 2022-05-10 16:04:08 +08:00 committed by GitHub
parent 45b9124df4
commit 4ca732349e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 1 deletions

View File

@ -232,16 +232,20 @@ class ColoTensor(object):
def __add__(self, o) -> "ColoTensor":
if isinstance(o, ColoTensor):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
elif isinstance(o, torch.Tensor):
elif isinstance(o, (torch.Tensor, int, float)):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
else:
raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
__radd__ = __add__
def __truediv__(self, o) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
def __getattr__(self, name):
def replace_tensor_with_colo(func):
def execute_func(*args, **kwargs):
# transform the ColoTensor args to torch Tensor.
args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
@ -282,3 +286,13 @@ class ColoTensor(object):
else:
raise NotImplementedError
return dim
def __mul__(self, other) -> "ColoTensor":
if isinstance(other, ColoTensor):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other.torch_tensor())
elif isinstance(other, (torch.Tensor, int, float)):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other)
else:
raise TypeError(f'{type(other)} is not supported in ColoTensor __mul__')
__rmul__ = __mul__