[DTensor] refactor dtensor with new components (#3089)

* [DTensor] refactor dtensor with new components

* polish
pull/3135/head
YuliangLiu0306 2023-03-14 16:25:47 +08:00 committed by GitHub
parent ed8f60b93b
commit 2eca4cd376
3 changed files with 20 additions and 41 deletions
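
Before the per-file diffs, a hedged before/after sketch of the constructor change this refactor introduces. The variable names `device_mesh` and `original_tensor` mirror the test diff below, and the exact `Layout` keyword set outside this diff is assumed:

import torch

from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# Old style (removed by this PR): the sharding spec bundled the mesh and global shape.
#   ShardingSpec(device_mesh=device_mesh,
#                entire_shape=original_tensor.shape,
#                dim_partition_dict={0: [0]})

# New style: ShardingSpec only maps tensor dims to device-mesh dims;
# the mesh, device type and global shape move onto Layout.
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh,
                device_type=torch.device('cuda'),
                sharding_spec=target_sharding_spec,
                entire_shape=original_tensor.shape)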

colossalai/tensor/d_tensor/d_tensor.py

@@ -3,12 +3,11 @@ from typing import Optional
 import torch
 from torch.utils._pytree import tree_map
 
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
-from colossalai.tensor.sharding_spec import ShardingSpec
+from .layout import Layout
+from .layout_converter import LayoutConverter, to_global
+from .sharding_spec import ShardingSpec
 
-shape_consistency_manager = ShapeConsistencyManager()
+layout_converter = LayoutConverter()
 
 
 class DTensor(torch.Tensor):
@@ -17,8 +16,6 @@ class DTensor(torch.Tensor):
         self.local_tensor = local_tensor
         self.data_type = local_tensor.dtype
         self.entire_shape = local_tensor.shape
-        if dist_layout.entire_shape is None:
-            dist_layout.entire_shape = self.entire_shape
         self.dist_layout = dist_layout
         self._apply_layout()
 
@@ -36,20 +33,19 @@ class DTensor(torch.Tensor):
         '''
         Convert the layout of the tensor from source_spec to target_spec.
         '''
-        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        target_spec = convert_layout_to_sharding_spec(target_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
-            self.local_tensor, source_spec, target_spec)
+        self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
         self.dist_layout = target_layout
 
     def _apply_layout(self):
         '''
         Apply the layout to the local tensor during initializing process.
         '''
-        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
-        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
-            self.local_tensor, source_spec, target_spec)
+        source_spec = construct_default_sharding_spec(self.local_tensor)
+        source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
+                               device_type=self.dist_layout.device_type,
+                               sharding_spec=source_spec,
+                               entire_shape=self.entire_shape)
+        self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -108,7 +104,7 @@ class DTensor(torch.Tensor):
         will not change the layout of the DTensor. This function is mainly used for debugging or
         check the correctness of the distributed tensor.
         '''
-        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))
+        return to_global(self.local_tensor, self.dist_layout)
 
 
 def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
@@ -139,20 +135,8 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
     return module
 
 
-def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
-    '''
-    Convert the layout from Layout class to ShardingSpec class.
-    '''
-    return ShardingSpec(device_mesh=layout.device_mesh,
-                        entire_shape=layout.entire_shape,
-                        dim_partition_dict=layout.sharding_spec.dim_partition_dict)
-
-
-def construct_default_sharding_spec(
-    tensor: torch.Tensor,
-    device_mesh: DeviceMesh,
-) -> ShardingSpec:
+def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
     '''
     Construct the default sharding specification for the tensor.
     '''
-    return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
+    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
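
The net effect in d_tensor.py: both initialization and later layout changes now funnel through the module-level `layout_converter`. A minimal sketch of the initialization path, using only constructors that appear in this diff (the helper name `shard_from_replicated` is illustrative, not part of the PR):

import torch

from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

layout_converter = LayoutConverter()

def shard_from_replicated(local_tensor: torch.Tensor, dist_layout: Layout) -> torch.Tensor:
    # Default spec: empty dim_partition_dict, i.e. no dimension is sharded yet.
    source_spec = ShardingSpec(dim_size=local_tensor.dim(), dim_partition_dict={})
    # The source layout reuses the target's mesh and device type but stays replicated.
    source_layout = Layout(device_mesh=dist_layout.device_mesh,
                           device_type=dist_layout.device_type,
                           sharding_spec=source_spec,
                           entire_shape=local_tensor.shape)
    # A single converter call performs the replicated -> sharded transition.
    return layout_converter.apply(local_tensor, source_layout, dist_layout)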

colossalai/tensor/d_tensor/layout_converter.py

@@ -22,21 +22,21 @@ __all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_o
 
 @dataclass
 class LayoutConverterOptions:
     """
-    LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency.
+    LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
     """
     # TODO: layout converter option is not implemented yet
     pass
 
 
 def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    shape_consistency_manager = LayoutConverter()
+    layout_converter = LayoutConverter()
     global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
     global_layout = Layout(device_mesh=layout.device_mesh,
                            device_type=layout.device_type,
                            sharding_spec=global_sharding_spec,
                            entire_shape=layout.entire_shape)
     with torch.no_grad():
-        global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout)
+        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
     return global_tensor
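
For completeness, a hedged usage sketch of the converter at this level: `to_global` above is just the special case whose target spec shards nothing, while `LayoutConverter.apply` handles any source/target pair. Here `device_mesh`, `global_shape`, and `local_shard` stand in for values created the way the test below does, and a 2-D global tensor is assumed:

import torch

from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter, to_global
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

converter = LayoutConverter()

# Re-shard dim 0 from mesh axis 0 only to mesh axes 0 and 1 (specs mirror the test below).
src_spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})       # assuming a 2-D tensor
tgt_spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0, 1]})
src_layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'),
                    sharding_spec=src_spec, entire_shape=global_shape)
tgt_layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'),
                    sharding_spec=tgt_spec, entire_shape=global_shape)

local_shard = converter.apply(local_shard, src_layout, tgt_layout)

# Gather back to the full tensor on every rank, e.g. for a correctness check.
full_tensor = to_global(local_shard, tgt_layout)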

DTensor test file

@@ -4,12 +4,11 @@ import torch
 import torch.multiprocessing as mp
 
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
 from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.utils import free_port
@@ -34,9 +33,7 @@ def check_dtensor(rank, world_size, port):
     compare_output = test_model(original_tensor)
 
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
-    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                        entire_shape=original_tensor.shape,
-                                        dim_partition_dict={0: [0]})
+    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
     layout = Layout(device_mesh=device_mesh,
                     device_type=torch.device('cuda'),
                     sharding_spec=target_sharding_spec,
@@ -62,9 +59,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
 
-    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                     entire_shape=original_tensor.shape,
-                                     dim_partition_dict={0: [0, 1]})
+    new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
     new_layout = Layout(device_mesh=device_mesh,
                         device_type=torch.device('cuda'),
                         sharding_spec=new_sharding_spec,