mirror of https://github.com/hpcaitech/ColossalAI

[DTensor] implementation of dtensor (#2946)

* [DTensor] implementation of dtensor
* test layout convert
* polish

pull/2951/head
parent 489a9566af
commit e414e4092b

@@ -0,0 +1,158 @@
from typing import Callable, Optional

import torch
from torch.utils._pytree import tree_map

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
from colossalai.tensor.sharding_spec import ShardingSpec

shape_consistency_manager = ShapeConsistencyManager()


class DTensor(torch.Tensor):

    def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
        self.local_tensor = local_tensor
        self.data_type = local_tensor.dtype
        self.entire_shape = local_tensor.shape
        if dist_layout.entire_shape is None:
            dist_layout.entire_shape = self.entire_shape
        self.dist_layout = dist_layout
        self._apply_layout()

    @staticmethod
    def __new__(cls, local_tensor, layout):
        return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)

    def __repr__(self):
        return f"DTensor({self.to_global()}, {self.dist_layout})"

    def __str__(self):
        return self.__repr__()

    def layout_convert(self, target_layout):
        '''
        Convert the layout of the tensor from its current layout to target_layout.
        '''
        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
        target_spec = convert_layout_to_sharding_spec(target_layout)
        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
            self.local_tensor, source_spec, target_spec)
        self.dist_layout = target_layout

    def _apply_layout(self):
        '''
        Apply the layout to the local tensor during the initialization process.
        '''
        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
            self.local_tensor, source_spec, target_spec)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def filter_arg(arg):
            if isinstance(arg, DTensor):
                return arg.local_tensor
            else:
                return arg

        args = tree_map(filter_arg, args)
        kwargs = tree_map(filter_arg, kwargs)
        # If we want to convert the result into a DTensor, we need to infer the layout of the
        # result from the layouts of the input tensors and the op type.

        return func(*args, **kwargs)

    @property
    def device_mesh(self):
        '''
        Return the device mesh of the tensor.
        '''
        return self.dist_layout.device_mesh

    @property
    def sharding_spec(self):
        '''
        Return the sharding specification of the tensor.
        '''
        return self.dist_layout.sharding_spec

    def to(self, *args, **kwargs):
        '''
        Move the tensor to a new device or convert the tensor to a new dtype.
        '''
        self.local_tensor = self.local_tensor.to(*args, **kwargs)
        self.data_type = self.local_tensor.dtype
        self.dist_layout.device_type = self.local_tensor.device
        # TODO: update the device mesh process groups, or simply cache
        # both the cpu process groups and the cuda process groups.
        return self

    def to_local(self):
        '''
        Return the local tensor on this rank.
        '''
        return self.local_tensor

    def to_global(self):
        '''
        Recover the global tensor from the distributed tensor.

        Note: This function will all_gather the local tensor into the global tensor and
        will not change the layout of the DTensor. It is mainly used for debugging or for
        checking the correctness of the distributed tensor.
        '''
        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))


def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
    '''
    Distribute the local tensor to the distributed tensor according to the dist_layout specified.

    Args:
        local_tensor: tensor to be distributed.
        dist_layout: the layout specification of the distributed tensor.

    Returns:
        A 'DTensor' object.
    '''
    return DTensor(local_tensor, dist_layout)


def distribute_module(module: torch.nn.Module, partition_fn: Optional[Callable] = None) -> torch.nn.Module:
    '''
    This function converts all the parameters in the module to DTensor(DParam).

    Note: This function is subject to future change as DParam has not been implemented yet.
    '''
    if partition_fn is None:
        # no partition function was given, so leave the module unchanged
        return module
    for name, param in module.named_parameters():
        if param is not None and not isinstance(param, DTensor):
            # TODO: we could convert the parameter to DParam here;
            # the type of the parameter could be an optional argument.
            setattr(module, name, torch.nn.Parameter(partition_fn(name, param.data)))
    return module


def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
    '''
    Convert the layout from the Layout class to the ShardingSpec class.
    '''
    return ShardingSpec(device_mesh=layout.device_mesh,
                        entire_shape=layout.entire_shape,
                        dim_partition_dict=layout.sharding_spec.dim_partition_dict)


def construct_default_sharding_spec(
    tensor: torch.Tensor,
    device_mesh: DeviceMesh,
) -> ShardingSpec:
    '''
    Construct the default (fully replicated) sharding specification for the tensor.
    '''
    return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
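A condensed usage sketch of the API above (not part of the diff; the variable names are illustrative). It assumes four ranks have already been initialized via colossalai.launch with the NCCL backend, mirroring the multi-process test at the end of this commit:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec

# assumes colossalai.launch(...) has already been called on 4 ranks
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
global_tensor = torch.rand(4, 8).cuda()

# shard tensor dim 0 over mesh dim 0: ranks (0, 1) and (2, 3) each hold two rows
spec = ShardingSpec(device_mesh=device_mesh,
                    entire_shape=global_tensor.shape,
                    dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh,
                device_type=torch.device('cuda'),
                sharding_spec=spec,
                entire_shape=global_tensor.shape)

d_tensor = distribute_tensor(global_tensor, layout)
local_chunk = d_tensor.to_local()    # the (2, 8) shard owned by this rank
recovered = d_tensor.to_global()     # all_gather back to the full (4, 8) tensor

# reshard in place: dim 0 is now split over both mesh dims, one row per rank
new_spec = ShardingSpec(device_mesh=device_mesh,
                        entire_shape=global_tensor.shape,
                        dim_partition_dict={0: [0, 1]})
new_layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=new_spec,
                    entire_shape=global_tensor.shape)
d_tensor.layout_convert(new_layout)

The test further below asserts exactly which rows each rank holds before and after the layout conversion.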
@@ -0,0 +1,22 @@
from dataclasses import dataclass

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec


@dataclass
class Layout:
    """Layout of a tensor.

    Attributes:
        device_mesh: the device mesh on which the tensor is distributed.
        device_type: the device type of the mesh, e.g. 'cpu' or 'cuda'.
        sharding_spec: the sharding specification describing how the tensor is sharded.
        entire_shape: the entire shape of the global tensor.
    """
    device_mesh: DeviceMesh
    device_type: torch.device
    sharding_spec: ShardingSpec
    entire_shape: torch.Size = None
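For reference, sharding_spec.dim_partition_dict maps each tensor dimension to the list of device-mesh dimensions it is sharded over; dimensions that do not appear stay replicated, which is what construct_default_sharding_spec encodes with an empty dict. A small sketch, reusing the 2x2 device_mesh from the sketch above (names are illustrative only):

# fully replicated: no tensor dimension is sharded, every rank holds the whole (4, 8) tensor
replicated = ShardingSpec(device_mesh=device_mesh,
                          entire_shape=torch.Size((4, 8)),
                          dim_partition_dict={})

# column-sharded: tensor dim 1 is split over mesh dim 1, each rank holds a (4, 4) slice
col_sharded = ShardingSpec(device_mesh=device_mesh,
                           entire_shape=torch.Size((4, 8)),
                           dim_partition_dict={1: [1]})

col_layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=col_sharded,
                    entire_shape=torch.Size((4, 8)))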
@@ -0,0 +1,104 @@
from functools import partial

import torch
import torch.multiprocessing as mp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.utils import free_port


class TestModel(torch.nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, out_features)
        self.linear_2 = torch.nn.Linear(out_features, in_features)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)
        return x


def check_dtensor(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_model = TestModel(8, 8).to('cuda')
    original_tensor = torch.rand(4, 8).to('cuda')
    compare_output = test_model(original_tensor)

    # build a 2x2 device mesh and shard tensor dim 0 over mesh dim 0
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                        entire_shape=original_tensor.shape,
                                        dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=target_sharding_spec)
    d_tensor = DTensor(original_tensor, layout)

    assert d_tensor.entire_shape == original_tensor.shape
    assert d_tensor.data_type == original_tensor.dtype

    if rank in (0, 1):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')
    assert d_tensor.to_global().equal(original_tensor)
    output = test_model(d_tensor)

    if rank in (0, 1):
        assert output.equal(compare_output.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert output.equal(compare_output.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

    # reshard: dim 0 is now split over both mesh dims, so each rank holds a single row
    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                     entire_shape=original_tensor.shape,
                                     dim_partition_dict={0: [0, 1]})
    new_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=new_sharding_spec,
                        entire_shape=original_tensor.shape)

    d_tensor.layout_convert(new_layout)

    if rank == 0:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

    dtensor_from_local = distribute_tensor(original_tensor, new_layout)

    if rank == 0:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')


def test_dtensor():
    world_size = 4
    run_func = partial(check_dtensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_dtensor()