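"""Distributed test for ColossalAI's DTensor.

Spawns four workers over a 2x2 device mesh and checks that:
  * a DTensor built from a 4x8 tensor (dim 0 sharded along mesh axis 0) reports the
    expected global shape, dtype and per-rank local shard;
  * a plain torch.nn model produces the same rows as the unsharded forward pass;
  * layout_convert re-shards the tensor so that every rank holds a single row;
  * distribute_tensor builds the same per-rank shards directly from the global tensor.
"""
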
from functools import partial

import torch
import torch.multiprocessing as mp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.utils import free_port


class TestModel(torch.nn.Module):
    """A two-layer linear model used to check that a DTensor flows through ordinary torch.nn modules."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, out_features)
        self.linear_2 = torch.nn.Linear(out_features, in_features)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)
        return x


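# Note (added for clarity; the mapping below is inferred from the assertions in
# check_dtensor rather than from ColossalAI documentation):
#   physical ids [0, 1, 2, 3] on a (2, 2) mesh -> rank 0: (0, 0), rank 1: (0, 1),
#                                                 rank 2: (1, 0), rank 3: (1, 1)
#   dim_partition_dict={0: [0]}    -> dim 0 split across mesh axis 0:
#                                     ranks 0/1 hold rows 0-1, ranks 2/3 hold rows 2-3
#   dim_partition_dict={0: [0, 1]} -> dim 0 split across both mesh axes: one row per rank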
def check_dtensor(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_model = TestModel(8, 8).to('cuda')
    original_tensor = torch.rand(4, 8).to('cuda')
    compare_output = test_model(original_tensor)

    # Shard dim 0 of the 4x8 tensor along mesh axis 0 of a 2x2 device mesh.
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                        entire_shape=original_tensor.shape,
                                        dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=target_sharding_spec)
    d_tensor = DTensor(original_tensor, layout)

    assert d_tensor.entire_shape == original_tensor.shape
    assert d_tensor.data_type == original_tensor.dtype

    # Ranks 0/1 should hold rows 0-1 and ranks 2/3 rows 2-3 of the original tensor.
    if rank in (0, 1):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')
    assert d_tensor.to_global().equal(original_tensor)
    output = test_model(d_tensor)

    # The sharded forward pass must match the corresponding rows of the unsharded output.
    if rank in (0, 1):
        assert output.equal(compare_output.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert output.equal(compare_output.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

    # Re-shard dim 0 along both mesh axes so that every rank ends up with a single row.
    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                     entire_shape=original_tensor.shape,
                                     dim_partition_dict={0: [0, 1]})
    new_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=new_sharding_spec,
                        entire_shape=original_tensor.shape)

    d_tensor.layout_convert(new_layout)

    if rank == 0:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

    # distribute_tensor should build the same per-rank shards directly from the global tensor.
    dtensor_from_local = distribute_tensor(original_tensor, new_layout)

    if rank == 0:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')


def test_dtensor():
    world_size = 4
    run_func = partial(check_dtensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


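# Assumed runtime requirements (inferred from the code, not stated explicitly): four CUDA
# devices are needed, since world_size is 4, the backend is 'nccl' and the model and
# tensors are moved to 'cuda'. The test can be run directly via the guard below or
# collected by a test runner such as pytest.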
if __name__ == '__main__':
    test_dtensor()