mirror of https://github.com/hpcaitech/ColossalAI
parent
45b9124df4
commit
4ca732349e
|
@ -232,16 +232,20 @@ class ColoTensor(object):
|
|||
def __add__(self, o) -> "ColoTensor":
|
||||
if isinstance(o, ColoTensor):
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
|
||||
elif isinstance(o, torch.Tensor):
|
||||
elif isinstance(o, (torch.Tensor, int, float)):
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
|
||||
else:
|
||||
raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
|
||||
|
||||
__radd__ = __add__
|
||||
|
||||
def __truediv__(self, o) -> "ColoTensor":
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
|
||||
|
||||
def __getattr__(self, name):
|
||||
|
||||
def replace_tensor_with_colo(func):
|
||||
|
||||
def execute_func(*args, **kwargs):
|
||||
# transform the ColoTensor args to torch Tensor.
|
||||
args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
|
||||
|
@ -282,3 +286,13 @@ class ColoTensor(object):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
return dim
|
||||
|
||||
def __mul__(self, other) -> "ColoTensor":
|
||||
if isinstance(other, ColoTensor):
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other.torch_tensor())
|
||||
elif isinstance(other, (torch.Tensor, int, float)):
|
||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other)
|
||||
else:
|
||||
raise TypeError(f'{type(other)} is not supported in ColoTensor __mul__')
|
||||
|
||||
__rmul__ = __mul__
|
||||
|
|
Loading…
Reference in New Issue