From 4ca732349ebd5e7dc65dd926eb2c1b654b3d43d7 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 10 May 2022 16:04:08 +0800 Subject: [PATCH] [tensor] colo tensor overrides mul (#927) * colo tensor overrides mul * polish code --- colossalai/tensor/colo_tensor.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index a0cb1bfe2..7b54b2e7f 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -232,16 +232,20 @@ class ColoTensor(object): def __add__(self, o) -> "ColoTensor": if isinstance(o, ColoTensor): return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor()) - elif isinstance(o, torch.Tensor): + elif isinstance(o, (torch.Tensor, int, float)): return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o) else: raise TypeError(f'{type(o)} is not supported in ColoTensor __add__') + __radd__ = __add__ + def __truediv__(self, o) -> "ColoTensor": return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o) def __getattr__(self, name): + def replace_tensor_with_colo(func): + def execute_func(*args, **kwargs): # transform the ColoTensor args to torch Tensor. args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args] @@ -282,3 +286,13 @@ class ColoTensor(object): else: raise NotImplementedError return dim + + def __mul__(self, other) -> "ColoTensor": + if isinstance(other, ColoTensor): + return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other.torch_tensor()) + elif isinstance(other, (torch.Tensor, int, float)): + return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other) + else: + raise TypeError(f'{type(other)} is not supported in ColoTensor __mul__') + + __rmul__ = __mul__