ColossalAI/colossalai/tensor/colo_tensor.py


import torch
from math import prod

from .op_wrapper import _COLOSSAL_OPS
from typing import Tuple


class ColoTensor(object):
    """Data structure for a tensor in Colossal-AI.

    1. It contains a torch.Tensor as an attribute.
    2. It supports lazy initialization of the tensor's payload.
    3. It hijacks torch functions that take ColoTensors as arguments and redirects them to our customized functions.
    4. It supports distributing the tensor's payload into shards across processes. (TODO)
    """

    def __new__(cls, *args, **kwargs):
        return super(ColoTensor, cls).__new__(cls)

    def __init__(
        self,
        *size: Tuple[int],
        dtype=None,
        requires_grad=False,
        pin_memory=False,
        device=None,
        torch_tensor=torch.empty(0),
    ):
        self._size = size
        self._dtype = dtype
        self._requires_grad = requires_grad
        self._pin_memory = pin_memory
        self._device = device
        self._torch_tensor = torch_tensor

    def numel(self):
        # The element count is the product of the dimensions, not their sum.
        return prod(self._size)

    @staticmethod
    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
        colo_t = ColoTensor(*tensor.size(),
                            dtype=tensor.dtype,
                            requires_grad=tensor.requires_grad,
                            pin_memory=tensor.is_pinned(),
                            device=tensor.device,
                            torch_tensor=tensor if save_payload else torch.empty(0))
        return colo_t

    def del_torch_tensor(self) -> None:
        # Release the payload by resetting the recorded size and replacing the tensor with an empty one.
        self._size = (0,)
        self._torch_tensor = torch.empty(self._size)

    def torch_tensor(self) -> torch.Tensor:
        # Lazy initialization: materialize the payload on first access.
        if self._torch_tensor.numel() == 0:
            self._torch_tensor = torch.empty(*self._size,
                                             dtype=self._dtype,
                                             pin_memory=self._pin_memory,
                                             requires_grad=self._requires_grad,
                                             device=self._device)
        return self._torch_tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        global _COLOSSAL_OPS
        if kwargs is None:
            kwargs = {}
        if func in _COLOSSAL_OPS:
            # Dispatch to the customized op as soon as any positional or keyword
            # argument turns out to be a ColoTensor.
            for arg in args:
                if isinstance(arg, ColoTensor):
                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
            for kwarg in kwargs.values():
                if isinstance(kwarg, ColoTensor):
                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
        else:
            # If we have not hijacked the function, convert the ColoTensors in args and kwargs to torch tensors.
            args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
            kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
            return func(*args, **kwargs)
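

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module. It assumes PyTorch is
    # installed and that this file is importable as part of colossalai.tensor (so the
    # relative import of _COLOSSAL_OPS resolves). It exercises payload wrapping, lazy
    # initialization, and the __torch_function__ fallback for un-hijacked functions.
    src = torch.randn(2, 3)
    colo_t = ColoTensor.init_from_torch_tensor(src)
    print(colo_t.numel())                # 6, the product of the dimensions

    lazy_t = ColoTensor(4, 5, dtype=torch.float32)
    print(lazy_t.torch_tensor().shape)   # the payload is allocated on this first access

    # torch functions that are not registered in _COLOSSAL_OPS fall through the
    # else-branch of __torch_function__ and run on the underlying torch.Tensor.
    out = torch.add(colo_t, colo_t)
    print(type(out), out.shape)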