mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
963 B
30 lines
963 B
3 years ago
|
import torch
|
||
3 years ago
|
from .op_wrapper import _COLOSSAL_OPS
|
||
3 years ago
|
|
||
|
|
||
3 years ago
|
class ColoTensor(object):
    """Thin wrapper around a ``torch.Tensor`` that reroutes torch functions.

    An instance holds a plain ``torch.Tensor`` and implements the
    ``__torch_function__`` protocol so that torch functions registered in
    ``_COLOSSAL_OPS`` are dispatched to their ColossalAI overrides whenever
    a ``ColoTensor`` appears among the arguments.
    """

    def __new__(cls, *args, **kwargs):
        # Plain object allocation; all state is assigned in __init__.
        return super(ColoTensor, cls).__new__(cls)

    def __init__(self, t: torch.Tensor) -> None:
        """Wrap the given tensor.

        Args:
            t: the underlying ``torch.Tensor`` this instance wraps.
        """
        self._torch_tensor = t

    def torch_tensor(self) -> torch.Tensor:
        """Return the wrapped ``torch.Tensor``."""
        return self._torch_tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func`` to a registered ColossalAI override.

        Args:
            func: the torch function being intercepted.
            types: the set of argument types (per the torch protocol).
            args: positional arguments passed to ``func``.
            kwargs: keyword arguments passed to ``func``; may be ``None``.

        Returns:
            The result of the registered override for ``func``.

        Raises:
            RuntimeError: if ``func`` has no registered override, or if none
                of its arguments is a ``ColoTensor``.
        """
        global _COLOSSAL_OPS
        # BUG FIX: the __torch_function__ protocol passes kwargs=None when
        # there are no keyword arguments; the original code iterated
        # ``kwargs.values()`` unconditionally, raising AttributeError for
        # every registered op called without kwargs.
        if kwargs is None:
            kwargs = {}
        if func in _COLOSSAL_OPS:
            # Dispatch when any positional or keyword argument is a ColoTensor.
            for arg in args:
                if isinstance(arg, ColoTensor):
                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
            for kwarg in kwargs.values():
                if isinstance(kwarg, ColoTensor):
                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
        raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
                           f"kwargs: {kwargs} not supported for ColoTensor!")