diff --git a/colossalai/tensor/d_tensor/d_tensor.py b/colossalai/tensor/d_tensor/d_tensor.py
new file mode 100644
index 000000000..e311eb3ba
--- /dev/null
+++ b/colossalai/tensor/d_tensor/d_tensor.py
@@ -0,0 +1,158 @@
+from typing import Optional
+
+import torch
+from torch.utils._pytree import tree_map
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+shape_consistency_manager = ShapeConsistencyManager()
+
+
+class DTensor(torch.Tensor):
+
+    def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
+        self.local_tensor = local_tensor
+        self.data_type = local_tensor.dtype
+        self.entire_shape = local_tensor.shape
+        if dist_layout.entire_shape is None:
+            dist_layout.entire_shape = self.entire_shape
+        self.dist_layout = dist_layout
+        self._apply_layout()
+
+    @staticmethod
+    def __new__(cls, local_tensor, layout):
+        return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
+
+    def __repr__(self):
+        return f"DTensor({self.to_global()}, {self.dist_layout})"
+
+    def __str__(self):
+        return self.__repr__()
+
+    def layout_convert(self, target_layout):
+        '''
+        Convert the layout of the tensor from its current layout to the given target_layout.
+        '''
+        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
+        target_spec = convert_layout_to_sharding_spec(target_layout)
+        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
+            self.local_tensor, source_spec, target_spec)
+        self.dist_layout = target_layout
+
+    def _apply_layout(self):
+        '''
+        Apply the layout to the local tensor during the initialization process.
+        '''
+        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
+        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
+        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
+            self.local_tensor, source_spec, target_spec)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        def filter_arg(arg):
+            if isinstance(arg, DTensor):
+                return arg.local_tensor
+            else:
+                return arg
+
+        args = tree_map(filter_arg, args)
+        kwargs = tree_map(filter_arg, kwargs)
+        # if we want to convert the result into a DTensor, we need to infer the layout of the result
+        # from the layouts of the input tensors and the op type.
+
+        return func(*args, **kwargs)
+
+    @property
+    def device_mesh(self):
+        '''
+        Return the device mesh of the tensor.
+        '''
+        return self.dist_layout.device_mesh
+
+    @property
+    def sharding_spec(self):
+        '''
+        Return the sharding specification of the tensor.
+        '''
+        return self.dist_layout.sharding_spec
+
+    def to(self, *args, **kwargs):
+        '''
+        Move the tensor to a new device or convert the tensor to a new dtype.
+        '''
+        self.local_tensor = self.local_tensor.to(*args, **kwargs)
+        self.data_type = self.local_tensor.dtype
+        self.dist_layout.device_type = self.local_tensor.device
+        # TODO: update the device mesh process groups, or should we just cache
+        # both the cpu process groups and the cuda process groups?
+        return self
+
+    def to_local(self):
+        '''
+        Return the local tensor on this rank.
+        '''
+        return self.local_tensor
+
+    def to_global(self):
+        '''
+        Recover the global tensor from the distributed tensor.
+
+        Note: This function will all_gather the local tensor to the global tensor and it
+        will not change the layout of the DTensor.
+        This function is mainly used for debugging or
+        checking the correctness of the distributed tensor.
+        '''
+        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))
+
+
+def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
+    '''
+    Distribute the local tensor into a distributed tensor according to the specified dist_layout.
+
+    Args:
+        local_tensor: tensor to be distributed.
+        dist_layout: the layout specification of the distributed tensor.
+
+    Returns:
+        A `DTensor` object.
+    '''
+    return DTensor(local_tensor, dist_layout)
+
+
+def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
+    '''
+    Convert all the parameters in the module to DTensor (DParam).
+
+    Note: This function is subject to future change as DParam has not been implemented yet.
+    '''
+    for name, param in module.named_parameters():
+        if param is not None and not isinstance(param, DTensor):
+            # TODO: we could convert the parameter to DParam here,
+            # the type of the parameter could be an optional argument.
+            setattr(module, name, torch.nn.Parameter(partition_fn(name, param.data)))
+    return module
+
+
+def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
+    '''
+    Convert a Layout object into the equivalent ShardingSpec object.
+    '''
+    return ShardingSpec(device_mesh=layout.device_mesh,
+                        entire_shape=layout.entire_shape,
+                        dim_partition_dict=layout.sharding_spec.dim_partition_dict)
+
+
+def construct_default_sharding_spec(
+    tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+) -> ShardingSpec:
+    '''
+    Construct the default sharding specification (no dimension is partitioned) for the tensor.
+    '''
+    return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py
new file mode 100644
index 000000000..9b72444aa
--- /dev/null
+++ b/colossalai/tensor/d_tensor/layout.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass
+
+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+
+@dataclass
+class Layout:
+    """Layout of a tensor.
+
+    Attributes:
+        device_mesh: the device mesh over which the tensor is distributed.
+        device_type: the device type of the device mesh, e.g. 'cpu' or 'cuda'.
+        sharding_spec: the sharding specification describing how the tensor is sharded.
+        entire_shape: the entire shape of the global tensor.
+    """
+    device_mesh: DeviceMesh
+    device_type: torch.device
+    sharding_spec: ShardingSpec
+    entire_shape: torch.Size = None
diff --git a/tests/test_tensor/test_dtensor.py b/tests/test_tensor/test_dtensor.py
new file mode 100644
index 000000000..1de9563a2
--- /dev/null
+++ b/tests/test_tensor/test_dtensor.py
@@ -0,0 +1,104 @@
+from functools import partial
+
+import torch
+import torch.multiprocessing as mp
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
+from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.utils import free_port
+
+
+class TestModel(torch.nn.Module):
+
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.linear_1 = torch.nn.Linear(in_features, out_features)
+        self.linear_2 = torch.nn.Linear(out_features, in_features)
+
+    def forward(self, x):
+        x = self.linear_1(x)
+        x = self.linear_2(x)
+        return x
+
+
+def check_dtensor(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    test_model = TestModel(8, 8).to('cuda')
+    original_tensor = torch.rand(4, 8).to('cuda')
+    compare_output = test_model(original_tensor)
+
+    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
+    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
+                                        entire_shape=original_tensor.shape,
+                                        dim_partition_dict={0: [0]})
+    layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=target_sharding_spec)
+    d_tensor = DTensor(original_tensor, layout)
+
+    assert d_tensor.entire_shape == original_tensor.shape
+    assert d_tensor.data_type == original_tensor.dtype
+
+    if rank in (0, 1):
+        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
+    elif rank in (2, 3):
+        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
+    else:
+        raise ValueError(f'rank {rank} is not in the device mesh')
+    assert d_tensor.to_global().equal(original_tensor)
+    output = test_model(d_tensor)
+
+    if rank in (0, 1):
+        assert output.equal(compare_output.narrow(0, 0, 2))
+    elif rank in (2, 3):
+        assert output.equal(compare_output.narrow(0, 2, 2))
+    else:
+        raise ValueError(f'rank {rank} is not in the device mesh')
+
+    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
+                                     entire_shape=original_tensor.shape,
+                                     dim_partition_dict={0: [0, 1]})
+    new_layout = Layout(device_mesh=device_mesh,
+                        device_type=torch.device('cuda'),
+                        sharding_spec=new_sharding_spec,
+                        entire_shape=original_tensor.shape)
+
+    d_tensor.layout_convert(new_layout)
+
+    if rank == 0:
+        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+    elif rank == 1:
+        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+    elif rank == 2:
+        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+    elif rank == 3:
+        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+    else:
+        raise ValueError(f'rank {rank} is not in the device mesh')
+
+    dtensor_from_local = distribute_tensor(original_tensor, new_layout)
+
+    if rank == 0:
+        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+    elif rank == 1:
+        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+    elif rank == 2:
+        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+    elif rank == 3:
+        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+    else:
+        raise ValueError(f'rank {rank} is not in the device mesh')
+
+
+def test_dtensor():
+    world_size = 4
+    run_func = partial(check_dtensor, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_dtensor()
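
A condensed usage sketch of the API introduced above, distilled from tests/test_tensor/test_dtensor.py. This is a sketch under assumptions, not part of the patch itself: it assumes it runs inside each of four spawned workers after colossalai.initialize.launch has set up a 4-rank NCCL process group with one GPU visible per rank; all class and function names come from the files added in this diff.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec

# Assumes a 4-rank process group is already initialized (e.g. via colossalai.initialize.launch).
# A deterministic tensor is used so every rank starts from the same global view.
global_tensor = torch.arange(32, dtype=torch.float32, device='cuda').reshape(4, 8)

# A 2x2 logical mesh over ranks [0, 1, 2, 3], as in the test above.
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

# Shard tensor dim 0 across mesh axis 0: each rank holds a (2, 8) slice.
row_spec = ShardingSpec(device_mesh=device_mesh,
                        entire_shape=global_tensor.shape,
                        dim_partition_dict={0: [0]})
row_layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=row_spec)
d_tensor = DTensor(global_tensor, row_layout)
print(d_tensor.to_local().shape)    # torch.Size([2, 8])

# Redistribute: shard dim 0 across both mesh axes, so each rank holds a (1, 8) slice.
full_spec = ShardingSpec(device_mesh=device_mesh,
                         entire_shape=global_tensor.shape,
                         dim_partition_dict={0: [0, 1]})
full_layout = Layout(device_mesh=device_mesh,
                     device_type=torch.device('cuda'),
                     sharding_spec=full_spec,
                     entire_shape=global_tensor.shape)
d_tensor.layout_convert(full_layout)
print(d_tensor.to_local().shape)    # torch.Size([1, 8])

# All-gather the shards back into the full global tensor (debugging/inspection only).
global_view = d_tensor.to_global()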