mirror of https://github.com/hpcaitech/ColossalAI
31 lines
1.0 KiB
Python
31 lines
1.0 KiB
Python
import torch
|
|
from .api import _STATEFUL_OPS
|
|
|
|
|
|
class StatefulTensorV2:
    """Wrap a ``torch.Tensor`` and route torch functions to registered handlers.

    The class implements the ``__torch_function__`` protocol: a torch function
    called with an instance among its arguments is looked up in the
    module-level ``_STATEFUL_OPS`` registry and delegated to the handler
    registered for it. Functions with no registered handler are rejected.
    """

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are consumed by __init__; object.__new__
        # must not receive them, so they are deliberately dropped here.
        return super().__new__(cls)

    def __init__(self, t: torch.Tensor) -> None:
        """Store ``t`` as the wrapped tensor (kept by reference, not copied)."""
        self._torch_tensor = t

    def torch_tensor(self) -> torch.Tensor:
        """Return the wrapped tensor that was passed to the constructor."""
        return self._torch_tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func`` to its handler in ``_STATEFUL_OPS``.

        Args:
            func: The torch function being overridden.
            types: Types forwarded unchanged to the handler.
            args: Positional arguments of the original call.
            kwargs: Keyword arguments of the original call. ``None`` is
                treated as an empty dict.

        Returns:
            Whatever the registered handler returns. The handler is called as
            ``handler(types, args, kwargs, None)``.

        Raises:
            RuntimeError: If ``func`` has no registered handler, or if no
                positional or keyword argument is a ``StatefulTensorV2``.
        """
        # The signature allows kwargs=None; normalise it so that scanning the
        # keyword values below cannot fail with AttributeError and so that
        # handlers always receive a dict.
        if kwargs is None:
            kwargs = {}

        if func in _STATEFUL_OPS:
            # Only dispatch when a StatefulTensorV2 actually takes part in
            # the call, either positionally or by keyword.
            has_stateful_operand = any(
                isinstance(operand, StatefulTensorV2)
                for operand in (*args, *kwargs.values())
            )
            if has_stateful_operand:
                # The fourth handler argument is always None here.
                # NOTE(review): presumably the slot for a process group -
                # confirm against the handler signature in .api.
                return _STATEFUL_OPS[func](types, args, kwargs, None)

        raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
                           f"kwargs: {kwargs} not supported for StatefulTensorV2!")
|