[DTensor] refactor dtensor with new components (#3089)

* [DTensor] refactor dtensor with new components

* polish
pull/3135/head
YuliangLiu0306 2023-03-14 16:25:47 +08:00 committed by GitHub
parent ed8f60b93b
commit 2eca4cd376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 41 deletions

View File

@ -3,12 +3,11 @@ from typing import Optional
import torch
from torch.utils._pytree import tree_map
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
from colossalai.tensor.sharding_spec import ShardingSpec
from .layout import Layout
from .layout_converter import LayoutConverter, to_global
from .sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
layout_converter = LayoutConverter()
class DTensor(torch.Tensor):
@ -17,8 +16,6 @@ class DTensor(torch.Tensor):
self.local_tensor = local_tensor
self.data_type = local_tensor.dtype
self.entire_shape = local_tensor.shape
if dist_layout.entire_shape is None:
dist_layout.entire_shape = self.entire_shape
self.dist_layout = dist_layout
self._apply_layout()
@ -36,20 +33,19 @@ class DTensor(torch.Tensor):
'''
Convert the layout of the tensor from source_spec to target_spec.
'''
source_spec = convert_layout_to_sharding_spec(self.dist_layout)
target_spec = convert_layout_to_sharding_spec(target_layout)
self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
self.local_tensor, source_spec, target_spec)
self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
self.dist_layout = target_layout
def _apply_layout(self):
'''
Apply the layout to the local tensor during initializing process.
'''
source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
target_spec = convert_layout_to_sharding_spec(self.dist_layout)
self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
self.local_tensor, source_spec, target_spec)
source_spec = construct_default_sharding_spec(self.local_tensor)
source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
device_type=self.dist_layout.device_type,
sharding_spec=source_spec,
entire_shape=self.entire_shape)
self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@ -108,7 +104,7 @@ class DTensor(torch.Tensor):
will not change the layout of the DTensor. This function is mainly used for debugging or
check the correctness of the distributed tensor.
'''
return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))
return to_global(self.local_tensor, self.dist_layout)
def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
@ -139,20 +135,8 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
return module
def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
'''
Convert the layout from Layout class to ShardingSpec class.
'''
return ShardingSpec(device_mesh=layout.device_mesh,
entire_shape=layout.entire_shape,
dim_partition_dict=layout.sharding_spec.dim_partition_dict)
def construct_default_sharding_spec(
tensor: torch.Tensor,
device_mesh: DeviceMesh,
) -> ShardingSpec:
def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
'''
Construct the default sharding specification for the tensor.
'''
return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})

View File

@ -22,21 +22,21 @@ __all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_o
@dataclass
class LayoutConverterOptions:
"""
LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency.
LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
"""
# TODO: layout converter option is not implemented yet
pass
def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
shape_consistency_manager = LayoutConverter()
layout_converter = LayoutConverter()
global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
global_layout = Layout(device_mesh=layout.device_mesh,
device_type=layout.device_type,
sharding_spec=global_sharding_spec,
entire_shape=layout.entire_shape)
with torch.no_grad():
global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout)
global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
return global_tensor

View File

@ -4,12 +4,11 @@ import torch
import torch.multiprocessing as mp
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.utils import free_port
@ -34,9 +33,7 @@ def check_dtensor(rank, world_size, port):
compare_output = test_model(original_tensor)
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=original_tensor.shape,
dim_partition_dict={0: [0]})
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=target_sharding_spec,
@ -62,9 +59,7 @@ def check_dtensor(rank, world_size, port):
else:
raise ValueError(f'rank {rank} is not in the device mesh')
new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=original_tensor.shape,
dim_partition_dict={0: [0, 1]})
new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
new_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=new_sharding_spec,