2022-04-26 07:10:47 +00:00
|
|
|
import torch
|
2022-05-19 04:44:59 +00:00
|
|
|
import pytest
|
2022-04-26 07:10:47 +00:00
|
|
|
from colossalai.tensor import ColoTensor
|
|
|
|
from numpy import allclose
|
|
|
|
|
2022-06-21 10:28:38 +00:00
|
|
|
import colossalai
|
|
|
|
from colossalai.utils import free_port
|
|
|
|
from colossalai.tensor import distspec, TensorSpec
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
import torch.multiprocessing as mp
|
|
|
|
from colossalai.testing import rerun_if_address_is_in_use
|
|
|
|
from colossalai.utils import free_port
|
2022-06-29 02:03:09 +00:00
|
|
|
from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
|
2022-06-21 10:28:38 +00:00
|
|
|
from colossalai.context import ParallelMode
|
|
|
|
from functools import partial
|
|
|
|
|
2022-04-26 07:10:47 +00:00
|
|
|
|
|
|
|
def test_tensor_indexing():
    """Slicing a ColoTensor must yield the same values as slicing its source tensor."""
    torch_t = torch.randn(2, 3)
    colo_t = ColoTensor(torch_t)
    # Compare with torch.allclose rather than numpy's allclose: the numpy
    # version converts both operands through Tensor.__array__ (i.e. .numpy()),
    # which raises for CUDA tensors and for tensors that require grad.
    assert torch.allclose(torch_t[:, 1], colo_t[:, 1])
|
2022-04-26 07:10:47 +00:00
|
|
|
|
|
|
|
|
2022-04-27 02:57:49 +00:00
|
|
|
def test_wrapped_tensor_func():
    """Check how torch calls on a ColoTensor are forwarded and wrapped.

    Covers four cases:
    - a plain attribute;
    - an op returning one tensor;
    - an op returning a non-tensor value;
    - an op returning several tensors.
    """
    reference = torch.randn(4, 5)
    colo = ColoTensor.from_torch_tensor(reference.clone())

    # Non-callable attribute: must match the wrapped tensor's value.
    assert colo.is_cuda == reference.is_cuda

    # Single-tensor result: must come back as a ColoTensor with equal values.
    absolute = colo.abs()
    assert isinstance(absolute, ColoTensor) and torch.equal(absolute, reference.abs())

    # Non-tensor result: passed through as-is.
    assert colo.dim() == reference.dim()

    # Multi-tensor result: every element must be a ColoTensor.
    first, second = colo.split(2)
    assert isinstance(first, ColoTensor) and isinstance(second, ColoTensor)
|
|
|
|
|
|
|
|
|
|
|
|
def test_operand():
    """Adding two ColoTensors must match adding the underlying torch tensors."""
    reference = torch.randn(4, 5)
    colo = ColoTensor.from_torch_tensor(reference.clone())
    expected = reference + reference
    assert torch.allclose(expected, colo + colo)
|
2022-05-30 09:23:44 +00:00
|
|
|
|
2022-06-21 10:28:38 +00:00
|
|
|
|
|
|
|
#### Distributed tests: initializing and viewing sharded / replicated ColoTensors
|
|
|
|
|
|
|
|
|
2022-06-27 01:45:26 +00:00
|
|
|
def _run_view(world_size):
    """Check global-shape queries and local/global views on a row-sharded ColoTensor.

    Each rank wraps a (4, 5) tensor as its shard along dim 0. The global shape
    reported by ``size_global`` is therefore (4 * world_size, 5).

    Args:
        world_size: number of ranks placed in the data-parallel process group.
    """
    local_rows, cols = 4, 5
    reference = torch.randn(local_rows, cols)
    pg = ProcessGroup(gpc.get_global_rank(), list(range(world_size)))
    assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"

    row_shard = distspec.shard(process_group=pg.dp_process_group(),
                               dims=[0],
                               num_partitions=[pg.dp_world_size()])
    colo = ColoTensor.from_torch_tensor(reference, TensorSpec(row_shard))

    global_rows = local_rows * world_size
    assert colo.size_global()[0] == global_rows
    assert colo.size_global(1) == cols
    assert colo.size_global() == torch.Size([global_rows, cols])

    # The view_local result is discarded; this only checks that the call
    # leaves `colo` still sharded ('s').
    colo.view_local(local_rows * cols)
    assert colo.tensor_spec.dist_spec.placement.value == 's'

    # view_global is expected to produce a replicated ('r') tensor covering
    # every rank's elements.
    colo = colo.view_global(local_rows * cols * world_size)
    assert colo.tensor_spec.dist_spec.placement.value == 'r'
    assert colo.shape == torch.Size([local_rows * cols * world_size])
|
|
|
|
|
|
|
|
|
2022-06-21 10:28:38 +00:00
|
|
|
def _run_tensor_shard_init(world_size):
    """Build a ColoTensor from a per-rank row shard, then switch it to replicated.

    After ``set_tensor_spec`` with a replicate spec, the tensor must have the
    full global shape (4 * world_size, 5).

    Args:
        world_size: number of ranks placed in the data-parallel process group.
    """
    local = torch.randn(4, 5)
    pg = ProcessGroup(gpc.get_global_rank(), list(range(world_size)))
    spec = TensorSpec(
        distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()]))
    colo = ColoTensor.from_torch_tensor(local.clone(), spec)
    # Re-specifying as replicated should gather the dim-0 shards from every rank.
    colo.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
    assert colo.shape == torch.Size((4 * world_size, 5))
|
|
|
|
|
|
|
|
|
|
|
|
def _run_tensor_replicated_init(world_size):
    """Wrap a full-size (4 * world_size, 5) tensor and check its shape is preserved.

    No TensorSpec is passed. Presumably this means the default replicated
    layout; verify against ColoTensor.from_torch_tensor.

    Args:
        world_size: used only to size the test tensor.
    """
    rows = 4 * world_size
    colo = ColoTensor.from_torch_tensor(torch.randn(rows, 5).clone())
    assert colo.shape == torch.Size((rows, 5)), f"{colo.shape}"
|
|
|
|
|
|
|
|
|
2022-06-27 01:45:26 +00:00
|
|
|
def run_dist_tests(rank, world_size, port):
    """Per-process worker: launch ColossalAI with NCCL on localhost, then run each distributed case.

    Args:
        rank: this process's global rank.
        world_size: total number of spawned processes.
        port: free localhost port used for the rendezvous.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    for case in (_run_tensor_shard_init, _run_tensor_replicated_init, _run_view):
        case(world_size)
|
2022-06-21 10:28:38 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_dist_cases(world_size):
    """Spawn ``world_size`` processes, each running ``run_dist_tests``.

    ``mp.spawn`` passes each process its index as the first positional argument,
    which ``run_dist_tests`` uses as the rank.
    """
    mp.spawn(partial(run_dist_tests, world_size=world_size, port=free_port()), nprocs=world_size)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Direct execution (outside pytest): run the distributed cases on two processes.
    test_dist_cases(2)
|