# ColossalAI/tests/test_tensor/test_dtensor.py

from functools import partial

import torch
import torch.multiprocessing as mp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.utils import free_port


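# A simple two-layer feed-forward model used to exercise DTensor in a forward pass.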
class TestModel(torch.nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, out_features)
        self.linear_2 = torch.nn.Linear(out_features, in_features)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)
        return x


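# Per-rank check: build a DTensor from a replicated tensor, then verify its local
# shards, global reconstruction, a forward pass, and conversion to a new layout.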
def check_dtensor(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_model = TestModel(8, 8).to('cuda')
    original_tensor = torch.rand(4, 8).to('cuda')
    compare_output = test_model(original_tensor)

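    # Build a 2 x 2 logical device mesh over ranks 0-3 and shard the tensor
    # along dim 0 across mesh dimension 0.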
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                        entire_shape=original_tensor.shape,
                                        dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=target_sharding_spec)
    d_tensor = DTensor(original_tensor, layout)

    assert d_tensor.entire_shape == original_tensor.shape
    assert d_tensor.data_type == original_tensor.dtype

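    # Ranks (0, 1) sit in the first row of the mesh and should hold rows 0-1;
    # ranks (2, 3) should hold rows 2-3.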
    if rank in (0, 1):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

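    # Gathering the shards back should reproduce the original tensor.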
    assert d_tensor.to_global().equal(original_tensor)

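    # A forward pass on the sharded tensor should match the corresponding rows
    # of the output computed on the full tensor.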
    output = test_model(d_tensor)
    if rank in (0, 1):
        assert output.equal(compare_output.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert output.equal(compare_output.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

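    # Convert to a layout that shards dim 0 over both mesh dimensions,
    # so each rank ends up with exactly one row.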
    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                     entire_shape=original_tensor.shape,
                                     dim_partition_dict={0: [0, 1]})
    new_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=new_sharding_spec,
                        entire_shape=original_tensor.shape)
    d_tensor.layout_convert(new_layout)

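    # After conversion, rank i should hold row i of the original tensor.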
    if rank == 0:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

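    # distribute_tensor should produce the same per-rank shards when sharding
    # the global tensor directly with the target layout.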
    dtensor_from_local = distribute_tensor(original_tensor, new_layout)
    if rank == 0:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')


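# Entry point: spawn one process per rank of the 2 x 2 mesh on a free port.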
def test_dtensor():
    world_size = 4
    run_func = partial(check_dtensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_dtensor()