# ColossalAI/colossalai/tensor/colo_tensor.py
from copy import copy
from typing import Optional

import torch
from torch.overrides import get_default_nowrap_functions

from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern

from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS


def _convert_output(output, pg: ProcessGroup):
    if type(output) == torch.Tensor:
        # Exact type check on purpose: ColoTensor outputs are left as-is and
        # other Tensor subclasses are not re-wrapped.
        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
    elif isinstance(output, (list, tuple)):
        # Recurse into containers so ops that return tuples/lists of tensors
        # get every element wrapped.
        return type(output)(_convert_output(o, pg) for o in output)
    else:
        return output


def _scan_for_pg_from_args(args, kwargs) -> Optional[ProcessGroup]:
    """Return the ProcessGroup of the first ColoTensor found in args/kwargs, or None."""
    for elem in args:
        if isinstance(elem, ColoTensor):
            return elem.get_process_group()
        elif isinstance(elem, (list, tuple)):
            pg = _scan_for_pg_from_args(elem, {})
            if pg is not None:
                return pg
    for v in kwargs.values():
        if isinstance(v, ColoTensor):
            return v.get_process_group()
    return None


class ColoTensor(torch.Tensor):
    """Data structure for tensors in Colossal-AI. It is a subclass of torch.Tensor.

    Args:
        data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
        spec (ColoTensorSpec, optional): the tensor spec used at initialization.
            Defaults to a replicated spec on a default DP ProcessGroup.

    The signature of __init__ has to be consistent with that of __new__, except for the first arg.
    The class can be initialized from a torch tensor in the following two ways.

    1. Direct init:

    >>> pg = ProcessGroup()
    >>> colo_t1 = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, distspec.replicate()))
    >>> # If initialized inside a sharded model, the tensor passed in is one shard of the global tensor.
    >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
    >>>                             dims=[0],
    >>>                             num_partitions=[world_size])
    >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)

    2. Use the static method from_torch_tensor:

    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, distspec.replicate()))
    """

    def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
        """
        The signature of __new__ has to be consistent with that of torch.Tensor.

        Args:
            data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
            spec (ColoTensorSpec, optional): the tensor spec used at initialization.

        Returns:
            ColoTensor: a ColoTensor wrapping the data.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
        # If no spec is given, fall back to a DP process group and a replicated dist spec.
        if spec is None:
            self.has_initialized = False
            self.dist_spec = distspec.replicate()
            self.compute_spec = None
            self.process_group = ProcessGroup()
        else:
            self.has_initialized = True
            self.dist_spec = spec.dist_attr
            self.compute_spec = spec.compute_attr
            if spec.pg is None:
                self.process_group = ProcessGroup()
            else:
                self.process_group = spec.pg

        self._type = TensorType.NONMODEL
        self._graph_node = None

    def has_compute_spec(self) -> bool:
        return self.compute_spec is not None

    def is_model_data(self) -> bool:
        return self._type == TensorType.MODEL

    def get_process_group(self) -> 'ProcessGroup':
        return self.process_group

    def set_process_group(self, pg: ProcessGroup):
        """set_process_group
        Change the pg of the ColoTensor. Note that the valid use cases are limited:
        it is only valid when the existing pg is a pure DP group and the dist spec is REPLICATE.

        Args:
            pg (ProcessGroup): target pg.

        Raises:
            RuntimeError: if the current process group has a tensor parallel world size > 1.
            RuntimeError: if the current dist spec is not REPLICATE.
        """
        assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
        # A pg with tensor parallelism implies the payload may be sharded across
        # ranks, so rebinding the process group would corrupt the tensor's meaning.
        if self.process_group.tp_world_size() != 1:
            raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")
        if self.dist_spec.placement.value != 'r':
            raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")

        self.process_group = pg

    def get_tp_world_size(self) -> int:
        return self.process_group.tp_world_size()

    def set_dist_spec(self, dist_spec: _DistSpec):
        """set_dist_spec
        Set the dist spec and convert the payload accordingly.

        Args:
            dist_spec (_DistSpec): target dist spec.
        """
        assert isinstance(dist_spec, _DistSpec)
        assert self.process_group is not None
        self._convert_to_dist_spec(dist_spec)

    def set_tensor_spec(self, dist_spec, compute_spec):
        if dist_spec:
            assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
            self.set_dist_spec(dist_spec)
        if compute_spec:
            self.compute_spec = compute_spec

    def has_compute_pattern(self, compute_pattern):
        return self.compute_spec.compute_pattern == compute_pattern

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        global _COLOSSAL_OPS
        if func in _COLOSSAL_OPS:
            # Swap in the distributed-aware implementation registered for this op.
            func = _COLOSSAL_OPS[func]

        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return ret
            else:
                # Wrap plain torch.Tensor outputs back into ColoTensors, reusing
                # the process group of the first ColoTensor found in the inputs.
                pg = _scan_for_pg_from_args(args, kwargs)
                return _convert_output(ret, pg)

    def __repr__(self):
        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'

    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
        """_convert_to_dist_spec
        Note that this function does not handle the logic of backward propagation!
        It is used as an internal function during model tensor initialization.

        Args:
            dist_spec (_DistSpec): the target dist. spec.
        """
        with DistSpecManager.no_grad():
            self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
        self.dist_spec = dist_spec

    def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
        ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))

    def to_replicate_(self):
        """to_replicate_
        An inplace member function which converts the dist spec of the tensor to REPLICATE.
        """
        self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
        self.dist_spec = distspec.replicate()

    def to_replicate(self) -> 'ColoTensor':
        """to_replicate
        Converts the dist spec of the tensor to REPLICATE and returns a new ColoTensor.
        """
        return self.convert_to_dist_spec(distspec.replicate())

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
        """Build a ColoTensor that shares storage with ``tensor`` by re-classing it
        and initializing the ColoTensor attributes from ``spec``."""
        tensor = tensor.as_subclass(ColoTensor)
        tensor.__init__(tensor, spec=spec)
        return tensor

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            with torch._C.DisableTorchFunction():
                data = self.data.clone()
            tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
            memo[id(self)] = tensor
            return tensor

    #### override builtin functions which must use tensor in replicate placement ####

    def view_local(self, *args) -> 'ColoTensor':
        return super().view(*args)

    def size_local(self, *args, **kwargs) -> torch.Size:
        return super().size(*args, **kwargs)

    def view_global(self, *args) -> 'ColoTensor':
        """Override the torch builtin view().
        The shape args are interpreted globally, so the tensor is first converted
        to REPLICATE placement if it is sharded.

        Returns:
            ColoTensor: the viewed tensor.
        """
        if self.is_replicate():
            return super().view(*args)
        # TODO(jiaruifang) check why this does not work
        # self.data = self.to_replicate()
        self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
                                                      self.process_group)
        self.dist_spec = distspec.replicate()
        return super().view(*args)

    def size_global(self, args: Optional[int] = None):
        """Override the torch builtin size().
        Returns the global size (the shape of the whole logical tensor), even when
        the local payload is only one shard of it.

        Returns:
            torch.Size or int: the global shape, or the global size of dimension ``args``.
        """
        if self.is_replicate():
            if args is not None:
                return super().size(args)
            else:
                return super().size()

        spec = self.dist_spec
        dims = spec.dims
        num_partitions = spec.num_partitions
        # Scale each sharded dimension by its partition count to recover the
        # global shape from the local shard's shape.
        size_list = list(super().size())
        for dim, num_partition in zip(dims, num_partitions):
            size_list[dim] *= num_partition

        if args is not None:
            return size_list[args]
        else:
            return torch.Size(size_list)

    # Some APIs for dist spec checks
    def is_replicate(self):
        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
            or (len(self.dist_spec.num_partitions) == 1
                and self.dist_spec.num_partitions[0] == 1) \
            or (self.process_group.tp_world_size() == 1)

    def is_shard_1dcol(self):
        return self.dist_spec.placement == DistPlacementPattern.SHARD \
            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1

    def is_shard_1drow(self):
        return self.dist_spec.placement == DistPlacementPattern.SHARD \
            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
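

# A minimal usage sketch (hedged: assumes the distributed environment has been
# initialized, e.g. via colossalai.launch, before ProcessGroup() is constructed):
#
#   pg = ProcessGroup()
#   t = ColoTensor.from_torch_tensor(torch.randn(2, 3), ColoTensorSpec(pg, distspec.replicate()))
#   assert t.is_replicate()
#   assert t.size_global() == torch.Size([2, 3])
#   out = t + t    # dispatched through __torch_function__; returns a ColoTensor
#   assert isinstance(out, ColoTensor)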