from .op_wrapper import _COLOSSAL_OPS

import torch
from typing import Tuple, Optional, Callable, Union
from numpy import prod

from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import divide
from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward

from .const import TensorType


class ColoTensor(object):
    """Data structure for a tensor in Colossal-AI.

    1. It holds a torch.Tensor as its payload attribute.
    2. It supports lazy initialization of the tensor payload.
    3. It hijacks torch functions that take ColoTensors as arguments and redirects them to our customized functions.
    4. It supports distributing the tensor payload into shards across processes. (TODO)
    """

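    # A minimal usage sketch (illustrative only; it assumes the default TensorSpec
    # and does not require a sharding setup):
    #
    #   t = torch.randn(4, 8)
    #   colo_t = ColoTensor.init_from_torch_tensor(t)
    #   y = colo_t + 1.0              # arithmetic returns a new ColoTensor
    #   s = torch.sum(colo_t)         # routed through __torch_function__
    #
    # Sharding is opt-in: call set_spec() with a TensorSpec describing the desired
    # ComputePattern, and shard()/gather() convert between the local shard and the
    # full tensor.
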
    def __new__(cls, *args, **kwargs):
        return super(ColoTensor, cls).__new__(cls)

    def __init__(self,
                 *size: Tuple[int],
                 dtype=None,
                 requires_grad=False,
                 pin_memory=False,
                 device=None,
                 torch_tensor=torch.empty(0),
                 shard_spec: TensorSpec = TensorSpec()):
        self._size = size
        self._dtype = dtype
        self._requires_grad = requires_grad
        self._pin_memory = pin_memory
        self._device = device
        self._torch_tensor = torch_tensor
        self._shard_spec = shard_spec
        self._shard_pattern = ShardPattern.NA
        self._type = TensorType.NONMODEL
        self._graph_node = None

    def __getitem__(self, key):
        return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])

    @property
    def shard_spec(self) -> TensorSpec:
        return self._shard_spec

    @property
    def shard_pattern(self):
        return self._shard_pattern

    @property
    def data(self):
        return self._torch_tensor.data

    @data.setter
    def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
        if isinstance(tensor, ColoTensor):
            self._torch_tensor.data = tensor.data
        elif isinstance(tensor, torch.Tensor):
            self._torch_tensor.data = tensor
        else:
            raise NotImplementedError

    @property
    def grad(self):
        return self._torch_tensor.grad

    @property
    def shape(self):
        return torch.Size(self._size)

    @property
    def device(self):
        return self._torch_tensor.device

    def size(self, dim=None):
        # NOTE: defined as a method (not a property) so that both tensor.size()
        # and tensor.size(dim) work, mirroring torch.Tensor.size().
        if dim is None:
            return self.shape
        return self._size[dim]

    def dim(self):
        return len(self._size)

    def normal_(self, mean=0., std=1.):
        torch_tensor = self.torch_tensor()
        return torch_tensor.normal_(mean=mean, std=std)

    def numel(self):
        return prod(self._size)

    @staticmethod
    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
        colo_t = ColoTensor(*tensor.size(),
                            dtype=tensor.dtype,
                            requires_grad=tensor.requires_grad,
                            pin_memory=tensor.is_pinned(),
                            device=tensor.device,
                            torch_tensor=tensor if save_payload else torch.empty(0))
        return colo_t

    def del_torch_tensor(self, save_shape=False) -> None:
        """
        Delete the payload of the torch tensor.

        Args:
            save_shape (bool, optional): whether to keep the shape of the torch_tensor.
                If the shape is kept, the size of self._torch_tensor becomes inconsistent
                with self._size. Defaults to False.
        """
        if not save_shape:
            self._size = (0,)
        self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)

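    # The payload is materialized lazily: torch_tensor() allocates the underlying
    # torch.Tensor on first access if the stored payload is still empty.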
    def torch_tensor(self) -> torch.Tensor:
        if self._torch_tensor.numel() == 0:
            self._torch_tensor = torch.empty(*self._size,
                                             dtype=self._dtype,
                                             pin_memory=self._pin_memory,
                                             requires_grad=self._requires_grad,
                                             device=self._device)
        return self._torch_tensor

    def set_spec(self, spec: TensorSpec, shard: bool = True) -> None:
        self._shard_spec = spec
        if shard:
            self.shard()

    def set_shard_pattern(self, shard_pattern: ShardPattern):
        self._shard_pattern = shard_pattern

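    # shard() maps the spec's ComputePattern onto a 1D shard dimension:
    # patterns that split the last dimension (e.g. TP1DRow_Linear) produce a
    # column shard (ShardPattern.Col), while patterns that split dim 0
    # (e.g. TP1DCol_Linear) produce a row shard (ShardPattern.Row).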
    def shard(self):
        assert self._shard_spec is not None, 'You should call set_spec() before calling shard() on a ColoTensor.'
        if self._shard_pattern is not ShardPattern.NA:    # reshard
            self.gather()
        # Model Parameters
        if self._shard_spec.num_action == 1:
            parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0])
            if parallel_action.compute_pattern in [
                    ComputePattern.TP1DRow_Linear, ComputePattern.TP1DCol_Embedding, ComputePattern.TP1DCol_mm
            ]:
                self._shard_1d(parallel_action=parallel_action, dim=-1)
                # We bind our ComputePattern to the weight, which has to be transposed in linear().
                self._shard_pattern = ShardPattern.Col
            elif parallel_action.compute_pattern in [
                    ComputePattern.TP1DCol_Linear, ComputePattern.TP1DRow_Embedding, ComputePattern.TP1DRow_mm
            ]:
                self._shard_1d(parallel_action=parallel_action, dim=0)
                self._shard_pattern = ShardPattern.Row
            else:
                raise NotImplementedError

    def gather(self):
        assert not self.is_model_data(), 'Currently we only support gathering activation (non-model-data) ColoTensors.'
        assert not self.is_gathered(), 'Only a sharded ColoTensor can be gathered.'
        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
        dim = self._get_gather_dim()
        self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
        self._shard_pattern = ShardPattern.NA
        self._size = self._torch_tensor.size()

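    # Unlike gather(), global_torch_tensor() does not modify this ColoTensor:
    # it returns a new, full-size torch.Tensor assembled via all_gather across
    # the parallel group, leaving the local shard untouched.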
    def global_torch_tensor(self) -> torch.Tensor:
        out_tensor = self.torch_tensor()
        if self.is_gathered():
            return out_tensor

        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
        world_size = gpc.get_world_size(parallel_action.parallel_mode)
        if world_size == 1:
            return out_tensor

        rank = gpc.get_local_rank(parallel_action.parallel_mode)
        tensor_list = [torch.empty_like(out_tensor) for _ in range(world_size)]
        tensor_list[rank] = out_tensor
        torch.distributed.all_gather(tensor_list, out_tensor, group=gpc.get_group(parallel_action.parallel_mode))

        dim = self._get_gather_dim()
        out_tensor = torch.cat(tensor_list, dim=dim).contiguous()

        return out_tensor

    def is_gathered(self) -> bool:
        return self._shard_pattern == ShardPattern.NA

    def has_spec(self) -> bool:
        return self._shard_spec is not None and self._shard_spec.num_action > 0

    def is_model_data(self) -> bool:
        return self._type == TensorType.MODEL

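    # Worked example for _shard_1d (illustrative numbers): with a payload of
    # shape (8, 16), dim=-1 and 4 partitions, divide(16, 4) gives chunk_size 4,
    # so local rank r keeps columns [4 * r, 4 * r + 4) of the original tensor.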
    def _shard_1d(self, parallel_action, dim=-1):
        num_partition = gpc.get_world_size(parallel_action.parallel_mode)
        local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
        chunk_size = divide(self._size[dim], num_partition)
        # Narrow down to the shard owned by this rank. We do not want autograd to
        # record the narrow op here, and the local shard should be a leaf variable
        # in the autograd graph.
        self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous()
        # TODO: shall we clone() here, since detach() still points to the old storage?
        self._torch_tensor.requires_grad = self._requires_grad
        self._size = self._torch_tensor.size()

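    # __torch_function__ is the dispatch hook: any torch function called with a
    # ColoTensor argument lands here. Functions registered in _COLOSSAL_OPS are
    # redirected to their Colossal-AI implementations; everything else is executed
    # by plain torch after unwrapping each ColoTensor into its torch.Tensor payload,
    # and torch.Tensor outputs are wrapped back into ColoTensors.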
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        global _COLOSSAL_OPS
        if func in _COLOSSAL_OPS:
            for arg in args:
                if isinstance(arg, ColoTensor):
                    return _COLOSSAL_OPS[func](types, args, kwargs, None)

            if kwargs is not None:
                for kwarg in kwargs.values():
                    if isinstance(kwarg, ColoTensor):
                        return _COLOSSAL_OPS[func](types, args, kwargs, None)
        else:
            # The function is not hijacked: convert the ColoTensors in args and kwargs to torch tensors.
            args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
            if kwargs is None:
                kwargs = {}

            kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
            return cls._filter_outputs_with_colo(func(*args, **kwargs))

    def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
        self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)

    def __add__(self, o) -> "ColoTensor":
        if isinstance(o, ColoTensor):
            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
        elif isinstance(o, (torch.Tensor, int, float)):
            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
        else:
            raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')

    __radd__ = __add__

    def __truediv__(self, o) -> "ColoTensor":
        return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)

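    # Attribute access falls through to the underlying torch.Tensor: plain
    # attributes are returned as-is, while callables are wrapped so that
    # ColoTensor arguments are unwrapped before the call and torch.Tensor
    # results are wrapped back into ColoTensors.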
    def __getattr__(self, name):

        def replace_tensor_with_colo(func):

            def execute_func(*args, **kwargs):
                # Transform the ColoTensor args and kwargs to torch tensors.
                args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
                kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
                return self._filter_outputs_with_colo(func(*args, **kwargs))

            return execute_func

        if not hasattr(self._torch_tensor, name):
            raise AttributeError(f"'ColoTensor' object has no attribute '{name}'")

        attr = getattr(self._torch_tensor, name)

        if isinstance(attr, Callable):
            return replace_tensor_with_colo(attr)
        else:
            return attr

    @classmethod
    def _filter_outputs_with_colo(cls, outputs):
        if outputs is None:    # return None
            return None
        elif type(outputs) is not tuple:    # number of return values = 1
            return ColoTensor.init_from_torch_tensor(outputs) if type(outputs) is torch.Tensor else outputs
        else:    # number of return values > 1
            return tuple([
                ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
                for output in outputs
            ])

    def _get_gather_dim(self):
        if self._shard_pattern == ShardPattern.Row:
            dim = 0
        elif self._shard_pattern == ShardPattern.Col:
            dim = -1
        else:
            raise NotImplementedError
        return dim

    def __mul__(self, other) -> "ColoTensor":
        if isinstance(other, ColoTensor):
            return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other.torch_tensor())
        elif isinstance(other, (torch.Tensor, int, float)):
            return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other)
        else:
            raise TypeError(f'{type(other)} is not supported in ColoTensor __mul__')

    __rmul__ = __mul__