[DTensor] implementation of dtensor (#2946)

* [DTensor] implementation of dtensor

* test layout convert

* polish
pull/2951/head
YuliangLiu0306 2 years ago committed by GitHub
parent 489a9566af
commit e414e4092b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,158 @@
from typing import Optional
import torch
from torch.utils._pytree import tree_map
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
from colossalai.tensor.sharding_spec import ShardingSpec
# Module-level manager shared by every DTensor created from this module; it is
# used by DTensor.layout_convert and DTensor._apply_layout to transform a local
# tensor from one sharding spec to another.
shape_consistency_manager = ShapeConsistencyManager()
class DTensor(torch.Tensor):
    '''
    Tensor subclass that couples a tensor held on this rank with the ``Layout`` describing it.

    Attributes (all assigned in ``__init__``):
        local_tensor: the tensor held by this instance. It is replaced by the result of applying
            ``dist_layout`` during construction, and again by ``layout_convert``.
        data_type: dtype of the tensor passed to the constructor; refreshed by ``to``.
        entire_shape: shape of the tensor passed to the constructor.
        dist_layout: the ``Layout`` currently associated with this tensor.
    '''

    def __init__(self, local_tensor: torch.Tensor, dist_layout: "Layout"):
        '''
        Record the tensor's metadata and apply ``dist_layout`` to it.

        Args:
            local_tensor: the tensor to wrap; its shape is stored as ``entire_shape``.
            dist_layout: the layout to apply. If its ``entire_shape`` is ``None`` it is filled in
                with ``local_tensor.shape`` (this mutates the caller's ``Layout`` object).
        '''
        self.local_tensor = local_tensor
        self.data_type = local_tensor.dtype
        self.entire_shape = local_tensor.shape
        if dist_layout.entire_shape is None:
            dist_layout.entire_shape = self.entire_shape
        self.dist_layout = dist_layout
        self._apply_layout()

    @staticmethod
    def __new__(cls, local_tensor, *args, **kwargs):
        '''
        Allocate the subclass instance on top of ``local_tensor``.

        Python passes the constructor's arguments to both ``__new__`` and ``__init__``. The layout
        is not needed for allocation, so it is accepted positionally or under any keyword and
        ignored here; this keeps ``DTensor(t, dist_layout=layout)`` working.
        '''
        return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)

    def __repr__(self):
        # NOTE: this calls to_global(), so printing a DTensor does the same work as to_global().
        return f"DTensor({self.to_global()}, {self.dist_layout})"

    def __str__(self):
        return self.__repr__()

    def layout_convert(self, target_layout):
        '''
        Convert the layout of the tensor from the current ``dist_layout`` to ``target_layout``.

        The local tensor is replaced by the converted one and ``dist_layout`` is updated in place.

        Args:
            target_layout: the ``Layout`` the tensor should have afterwards.
        '''
        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
        target_spec = convert_layout_to_sharding_spec(target_layout)
        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
            self.local_tensor, source_spec, target_spec)
        self.dist_layout = target_layout

    def _apply_layout(self):
        '''
        Apply the layout to the local tensor during initializing process.

        The tensor is treated as having the default sharding spec (empty ``dim_partition_dict``)
        and is converted from that spec to the one described by ``dist_layout``.
        '''
        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
            self.local_tensor, source_spec, target_spec)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        '''
        Run ``func`` on the local tensors of any ``DTensor`` arguments.

        Every ``DTensor`` found in ``args``/``kwargs`` (including inside nested containers, via
        ``tree_map``) is replaced by its ``local_tensor`` before ``func`` is called. The result of
        ``func`` is returned as-is; it is not wrapped back into a ``DTensor``.
        '''
        if kwargs is None:
            kwargs = {}

        def filter_arg(arg):
            if isinstance(arg, DTensor):
                return arg.local_tensor
            else:
                return arg

        args = tree_map(filter_arg, args)
        kwargs = tree_map(filter_arg, kwargs)
        # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
        # and op type.

        return func(*args, **kwargs)

    @property
    def device_mesh(self):
        '''
        Return the device mesh of the tensor.
        '''
        return self.dist_layout.device_mesh

    @property
    def sharding_spec(self):
        '''
        Return the sharding specification of the tensor.
        '''
        return self.dist_layout.sharding_spec

    def to(self, *args, **kwargs):
        '''
        Move the tensor to a new device or convert the tensor to a new dtype.

        The arguments are forwarded to ``torch.Tensor.to`` on the local tensor. ``data_type`` and
        ``dist_layout.device_type`` are refreshed from the result. The DTensor is updated in place
        and ``self`` is returned.
        '''
        self.local_tensor = self.local_tensor.to(*args, **kwargs)
        self.data_type = self.local_tensor.dtype
        self.dist_layout.device_type = self.local_tensor.device
        # TODO: update the device mesh process groups or we should just cache
        # both the cpu process groups and the cuda process groups?
        return self

    def to_local(self):
        '''
        Return the local tensor in this rank.
        '''
        return self.local_tensor

    def to_global(self):
        '''
        Recover the global tensor from the distributed tensor.

        Note: This function will all_gather the local tensor to the global tensor and it
        will not change the layout of the DTensor. This function is mainly used for debugging or
        check the correctness of the distributed tensor.
        '''
        # Inside a method the bare name resolves to the module-level ``to_global`` helper,
        # not to this method.
        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))
def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
    '''
    Wrap a tensor in a DTensor that follows the given layout.

    This is a thin functional wrapper around the ``DTensor`` constructor.

    Args:
        local_tensor: tensor to be distributed.
        dist_layout: the layout specification of the distributed tensor.

    Returns:
        The newly constructed 'DTensor' object.
    '''
    d_tensor = DTensor(local_tensor, dist_layout)
    return d_tensor
def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
    '''
    This function converts all the parameters in the module to DTensor(DParam).

    Note: This function is subject to future change as the DParam has not been implemented yet.

    Args:
        module: the module whose parameters are replaced in place.
        partition_fn: callable invoked as ``partition_fn(name, param.data)``, where ``name`` is the
            parameter's fully qualified (dotted) name relative to ``module``. Its return value is
            wrapped in a new ``torch.nn.Parameter``. When ``None``, nothing is converted.

    Returns:
        The same ``module`` object that was passed in.
    '''
    # No partition function means there is nothing to convert.
    if partition_fn is None:
        return module

    # Maps id(original parameter) -> (original, replacement) so a parameter shared by several
    # submodules is partitioned once and stays shared. Keeping the original in the tuple keeps
    # it alive, so its id cannot be reused while we iterate.
    replaced = {}
    for module_name, submodule in module.named_modules():
        # Parameters must be set on the submodule that owns them: torch.nn.Module rejects
        # parameter names containing '.', so setattr(module, 'layer.weight', ...) would fail.
        for param_name, param in list(submodule.named_parameters(recurse=False)):
            if isinstance(param, DTensor):
                continue
            key = id(param)
            if key not in replaced:
                full_name = f'{module_name}.{param_name}' if module_name else param_name
                # TODO: we could convert the parameter to DParam here,
                # the type of the parameter could be an optional argument.
                replaced[key] = (param, torch.nn.Parameter(partition_fn(full_name, param.data)))
            setattr(submodule, param_name, replaced[key][1])
    return module
def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
    '''
    Build a new ShardingSpec from the information stored in a Layout.

    The device mesh and entire shape come from the layout itself; the dimension partitioning is
    taken from the layout's own ``sharding_spec``.
    '''
    spec_kwargs = {
        'device_mesh': layout.device_mesh,
        'entire_shape': layout.entire_shape,
        'dim_partition_dict': layout.sharding_spec.dim_partition_dict,
    }
    return ShardingSpec(**spec_kwargs)
def construct_default_sharding_spec(tensor: torch.Tensor, device_mesh: DeviceMesh) -> ShardingSpec:
    '''
    Build the default sharding specification for ``tensor`` on ``device_mesh``.

    The default spec uses the tensor's own shape as the entire shape and an empty
    ``dim_partition_dict``, i.e. no tensor dimension is mapped to a mesh axis.
    '''
    no_partition = {}
    default_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict=no_partition)
    return default_spec

@ -0,0 +1,22 @@
from dataclasses import dataclass
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
@dataclass
class Layout:
    """Layout of a tensor.

    Plain data holder; the field order below is also the positional order of the
    generated constructor.

    Attributes:
        device_mesh: the device mesh to store the tensor distributedly.
        device_type: the device the tensor lives on, given as a ``torch.device``
            (e.g. ``torch.device('cpu')`` or ``torch.device('cuda')``).
        sharding_spec: the sharding specification to describe how the tensor is sharded.
        entire_shape: the entire shape of the global tensor. Defaults to ``None``,
            meaning the shape has not been specified.
    """
    device_mesh: DeviceMesh
    device_type: torch.device
    sharding_spec: ShardingSpec
    # Optional: None means "not specified".
    entire_shape: torch.Size = None

@ -0,0 +1,104 @@
from functools import partial
import torch
import torch.multiprocessing as mp
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.utils import free_port
class TestModel(torch.nn.Module):
    """Two stacked linear layers mapping in_features -> out_features -> in_features."""

    def __init__(self, in_features, out_features):
        """Create the two linear layers.

        Args:
            in_features: input width of the first layer and output width of the second.
            out_features: output width of the first layer and input width of the second.
        """
        super().__init__()
        # Creation order is kept: linear_1 is initialised and registered before linear_2.
        self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=out_features)
        self.linear_2 = torch.nn.Linear(in_features=out_features, out_features=in_features)

    def forward(self, x):
        """Apply linear_1 followed by linear_2 and return the result."""
        return self.linear_2(self.linear_1(x))
def check_dtensor(rank, world_size, port):
    """Per-process body of the DTensor test.

    Launches the distributed environment, builds a 2x2 device mesh over ranks 0-3 and checks, in
    order: construction of a DTensor split two ways along dim 0, ``to_global``, running a model
    on the DTensor, ``layout_convert`` to a four-way split along dim 0, and ``distribute_tensor``
    with that four-way layout.

    Args:
        rank: rank of this process.
        world_size: total number of processes.
        port: port passed to ``launch``.

    Raises:
        ValueError: if ``rank`` is not one of 0, 1, 2, 3.
    """
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    def two_way_rows(full):
        # Ranks 0 and 1 are expected to hold rows 0-1, ranks 2 and 3 rows 2-3.
        if rank in (0, 1):
            return full.narrow(0, 0, 2)
        if rank in (2, 3):
            return full.narrow(0, 2, 2)
        raise ValueError(f'rank {rank} is not in the device mesh')

    def four_way_row(full):
        # Each rank is expected to hold exactly the row whose index equals its rank.
        if rank not in (0, 1, 2, 3):
            raise ValueError(f'rank {rank} is not in the device mesh')
        return full.narrow(0, rank, 1)

    test_model = TestModel(8, 8).to('cuda')
    original_tensor = torch.rand(4, 8).to('cuda')
    compare_output = test_model(original_tensor)

    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                        entire_shape=original_tensor.shape,
                                        dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=target_sharding_spec)
    d_tensor = DTensor(original_tensor, layout)

    # Metadata recorded by the constructor.
    assert d_tensor.entire_shape == original_tensor.shape
    assert d_tensor.data_type == original_tensor.dtype

    # Local shard after the two-way split, and round trip back to the global tensor.
    expected_local = two_way_rows(original_tensor)
    assert d_tensor.to_local().equal(expected_local)
    assert d_tensor.to_global().equal(original_tensor)

    # Running the model on the DTensor should match the corresponding rows of the reference output.
    output = test_model(d_tensor)
    expected_output = two_way_rows(compare_output)
    assert output.equal(expected_output)

    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                     entire_shape=original_tensor.shape,
                                     dim_partition_dict={0: [0, 1]})
    new_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=new_sharding_spec,
                        entire_shape=original_tensor.shape)

    # Convert the existing DTensor to the four-way layout.
    d_tensor.layout_convert(new_layout)
    expected_row = four_way_row(original_tensor)
    assert d_tensor.local_tensor.equal(expected_row)

    # Build a fresh DTensor directly with the four-way layout.
    dtensor_from_local = distribute_tensor(original_tensor, new_layout)
    expected_row = four_way_row(original_tensor)
    assert dtensor_from_local.local_tensor.equal(expected_row)
def test_dtensor():
    """Spawn four processes, each running ``check_dtensor`` with its own rank."""
    num_procs = 4
    # mp.spawn supplies the process index as the first positional argument (rank).
    worker = partial(check_dtensor, world_size=num_procs, port=free_port())
    mp.spawn(worker, nprocs=num_procs)
# Allow running this test file directly as a script.
if __name__ == '__main__':
    test_dtensor()
Loading…
Cancel
Save