mirror of https://github.com/hpcaitech/ColossalAI
[ColoTensor] improves init functions. (#1150)
parent 8106d7b8c7
commit 8cdce0399c
@@ -35,7 +35,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
                  data: Optional[torch.Tensor] = None,
                  requires_grad: bool = True,
                  spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
-        self._spec = copy(spec)
+        self._tensor_spec = copy(spec)
         self._type = TensorType.MODEL
         self._graph_node = None
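For context, a minimal sketch of constructing a ColoParameter with the signature shown above. This is illustrative only, not part of the commit, and it assumes ColoParameter is importable from colossalai.tensor.colo_parameter in this version of ColossalAI.

import torch
from colossalai.tensor import TensorSpec, distspec
from colossalai.tensor.colo_parameter import ColoParameter  # assumed import path

# the spec is copied into the renamed _tensor_spec field, and the parameter
# tags itself as model data (TensorType.MODEL), unlike a plain ColoTensor
p = ColoParameter(torch.randn(8, 8),
                  requires_grad=True,
                  spec=TensorSpec(distspec.replicate()))
assert p.is_model_data()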
@@ -1,12 +1,13 @@
 from .op_wrapper import _COLOSSAL_OPS
+from .const import TensorType
 from copy import copy
 import torch
+from torch.overrides import get_default_nowrap_functions
+
 from colossalai.tensor import TensorSpec
-from .const import TensorType
 from colossalai.tensor import distspec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec
-from torch.overrides import get_default_nowrap_functions


 def _convert_output(output):
@@ -18,34 +19,54 @@ def _convert_output(output):


 class ColoTensor(torch.Tensor):
-    """ Data Structure for Tensor in Colossal-AI
-    1. It contains a torch.Tensor as an attribute.
-    2. It supports lazy init the tensor's payload.
-    3. It can hijack the torch functions which using ColoTensors as args to our customized functions.
-    4. It supports distributing the tensor's payload to the shards among processes. (TODO)
+    """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
+    Args:
+        data (torch.Tensor): a torch tensor used as the payload of the colotensor.
+        spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate()).
+
+    The signature of __init__ has to be consistent with that of __new__ except for the 1st arg.
+    The class should be initialized with a torch tensor in the following ways.
+    1. directly init.
+    >>> colo_t1 = ColoTensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
+    >>> # If initialized in a sharded model, the tensor passed in is one shard of the global tensor.
+    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+    >>>                             dims=[0],
+    >>>                             num_partitions=[world_size])
+    >>> tensor_spec = TensorSpec(shard_spec)
+    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    2. use the static method from_torch_tensor.
+    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
     """

     def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
+        """__new__
+        The signature of __new__ has to be consistent with that of torch.Tensor.
+        Args:
+            data (torch.Tensor): a torch tensor used as the payload of the colotensor.
+            spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate()).
+        Returns:
+            ColoTensor: a ColoTensor that wraps the data.
+        """
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, data.requires_grad)

     def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
-        self._spec = copy(spec)
+        self._tensor_spec = copy(spec)
         self._type = TensorType.NONMODEL
         self._graph_node = None

     @property
     def spec(self) -> TensorSpec:
-        return self._spec
+        return self._tensor_spec

     def set_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
-        self.convert_to_dist_spec_(spec.dist_spec)
-        self._spec = spec
+        self._convert_to_dist_spec(spec.dist_spec)
+        self._tensor_spec = spec

     def has_spec(self) -> bool:
-        return self._spec.parallel_action is not None
+        return self._tensor_spec.parallel_action is not None

     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL
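The new class docstring above describes two initialization paths. A minimal sketch of both, illustrative only and assuming colossalai.launch() has already set up the process groups so that gpc and ParallelMode.DATA are usable:

import torch
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import ColoTensor, TensorSpec, distspec

# 1. direct init: the payload is replicated across processes
colo_t1 = ColoTensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))

# 2. from_torch_tensor with a shard spec: each rank passes in its own shard
#    of the global tensor, split along dim 0 over the DATA group
world_size = gpc.get_world_size(ParallelMode.DATA)
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
                            dims=[0],
                            num_partitions=[world_size])
colo_t2 = ColoTensor.from_torch_tensor(torch.randn(2, 3), TensorSpec(shard_spec))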
@@ -74,16 +95,16 @@ class ColoTensor(torch.Tensor):
     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL

-    def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None:
+    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
         with DistSpecManager.no_grad():
             self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        self._spec.dist_spec = dist_spec
+        self._tensor_spec.dist_spec = dist_spec

     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
-        spec = copy(self._spec)
-        spec.dist_spec = dist_spec
+        tensor_spec = copy(self._tensor_spec)
+        tensor_spec.dist_spec = dist_spec
         ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        return ColoTensor.from_torch_tensor(ret, spec)
+        return ColoTensor.from_torch_tensor(ret, tensor_spec)

     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
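The hunk above pairs with the rename: set_spec() changes the layout in place through the private _convert_to_dist_spec(), while convert_to_dist_spec() leaves the original tensor untouched and returns a new ColoTensor. A short sketch, illustrative only and assuming a launched ColossalAI context:

import torch
from colossalai.tensor import ColoTensor, TensorSpec, distspec

t = ColoTensor.from_torch_tensor(torch.randn(4, 5))  # replicated by default

# out of place: t keeps its spec, t2 carries the requested dist spec
t2 = t.convert_to_dist_spec(distspec.replicate())

# in place: the spec is copied, the payload is converted, and _tensor_spec is replaced
t.set_spec(TensorSpec(dist_spec=distspec.replicate()))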
@@ -4,6 +4,7 @@ from numpy import prod
 from contextlib import contextmanager
 import torch
 import torch.distributed as dist
+from packaging import version


 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -56,6 +57,12 @@ class DistSpecManager:

     @staticmethod
     def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
+        if version.parse(torch.__version__) < version.parse("1.11.0"):
+            # pytorch lower than 1.11 does not support gathering a cpu tensor.
+            # Therefore, we transfer the tensor to GPU before the gather.
+            saved_dev = tensor.device
+            tensor.data = tensor.data.cuda()
+
         buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
         dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
         for i in range(len(old_dist_spec.dims) - 1, -1, -1):
@@ -66,6 +73,9 @@ class DistSpecManager:
                 new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
             buffer = new_buffer
         assert len(buffer) == 1
+
+        if version.parse(torch.__version__) < version.parse("1.11.0"):
+            buffer[0].data = buffer[0].data.to(saved_dev)
         return buffer[0]

     @staticmethod
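The two hunks above implement one workaround: PyTorch releases before 1.11 cannot all_gather a CPU tensor, so _gather() moves the data to the GPU for the collective and back afterwards. A standalone sketch of the same pattern; gather_cpu_safe is a hypothetical helper, not part of the commit, and it assumes an initialized torch.distributed process group with one CUDA device per rank:

import torch
import torch.distributed as dist
from packaging import version


def gather_cpu_safe(tensor: torch.Tensor, group=None) -> list:
    # on torch < 1.11, round-trip through the GPU for the collective
    needs_round_trip = version.parse(torch.__version__) < version.parse("1.11.0")
    saved_dev = tensor.device
    if needs_round_trip:
        tensor = tensor.cuda()
    buffer = [torch.empty_like(tensor) for _ in range(dist.get_world_size(group))]
    dist.all_gather(buffer, tensor, group=group)
    if needs_round_trip:
        buffer = [chunk.to(saved_dev) for chunk in buffer]
    return buffer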
@@ -24,28 +24,13 @@ class ParallelAction(object):

 class TensorSpec(object):
     """
-    It contains two aspects of information:
-    First, How are tensors distributed in Heterougenous memory space.
-    Second, if the tensor is a model parameter, the Spec contains the
-    parallel computation pattern of the Operator (Layer).
-    We have to consider the hybrid parallel mode.
+    The specification of the ColoTensor.
+    Args:
+        dist_spec (_DistSpec): describing the layout among processes.
+        parallel_action (Optional[ParallelAction], optional): actions conducted on the tensor after initialization if it's a model data tensor.
+            Defaults to None.
     """

-    # a list of parallel actions.
-    # For example: On 8 GPUs, a hybrid parallel strategy is applied using
-    # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
-    # parallel_action_list = [
-    #     ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
-    #     ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
-    # ]
-    # When the ColoTensor is initialized,
-    # we first splitting tensor according to ParallelAction of ZeRO,
-    # then splitting tensor according to ParallelAction of TP1D_Linear.
-    # During Linear computation
-    # Before Linear Op, we gather the tensors according to ZeRO.
-    # We perform Linear Op according to compute pattern of TP1D_Linear.
-    # After Linear Op, we split the tensors according to ZeRO.
-
     def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
         self.parallel_action = parallel_action
         self.dist_spec = dist_spec
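Per the rewritten docstring, a TensorSpec simply pairs a layout (_DistSpec) with an optional ParallelAction. A minimal sketch, illustrative only:

from colossalai.tensor import TensorSpec, distspec

# layout only: the tensor is replicated and parallel_action stays None,
# so ColoTensor.has_spec() on a tensor carrying this spec returns False
replicated_spec = TensorSpec(distspec.replicate())
print(replicated_spec.dist_spec, replicated_spec.parallel_action)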
@@ -3,6 +3,17 @@ import pytest
 from colossalai.tensor import ColoTensor
 from numpy import allclose

+import colossalai
+from colossalai.utils import free_port
+from colossalai.tensor import distspec, TensorSpec
+from colossalai.core import global_context as gpc
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.tensor import distspec, TensorSpec, ColoTensor
+from colossalai.context import ParallelMode
+from functools import partial
+

 def test_tensor_indexing():
     torch_t = torch.randn(2, 3)
@@ -25,8 +36,6 @@ def test_wrapped_tensor_func():
     # non-func attr
     assert t.is_cuda == t_ref.is_cuda

-    # TODO I don't find out a tensor function which returns None.
-
     # return 1 torch.Tensor
     t_abs = t.abs()
     assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())
@@ -47,3 +56,41 @@ def test_operand():
     t_res = t + t
     assert torch.allclose(t_ref_res, t_res)

+
+#### Test distributed init of a ColoTensor
+
+
+def _run_tensor_shard_init(world_size):
+    t_ref = torch.randn(4, 5)
+    print(gpc.get_group(ParallelMode.DATA).size())
+    shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
+    tensor_spec = TensorSpec(shard_spec)
+    t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
+    assert t.shape == torch.Size((4 * world_size, 5))
+
+
+def _run_tensor_replicated_init(world_size):
+    t_ref = torch.randn(4 * world_size, 5)
+    t = ColoTensor.from_torch_tensor(t_ref.clone())
+
+    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
+
+
+def run_tensor_init(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    _run_tensor_shard_init(world_size)
+    _run_tensor_replicated_init(world_size)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def _test_dist_init(world_size):
+    run_func = partial(run_tensor_init, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    # _test_dist_init(4)
+    test_new()
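A usage note on the new test: the leading underscore keeps _test_dist_init out of pytest's default collection, so the distributed path is driven manually. An illustrative invocation, assuming the NCCL backend, enough GPUs, and that it runs inside the test module where run_tensor_init is defined:

from functools import partial
import torch.multiprocessing as mp
from colossalai.utils import free_port

# spawn one process per rank, mirroring the body of _test_dist_init
world_size = 2
run_func = partial(run_tensor_init, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)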