mirror of https://github.com/hpcaitech/ColossalAI
[DTensor] refactor dtensor with new components (#3089)
* [DTensor] refactor dtensor with new components
* polish

pull/3135/head
parent ed8f60b93b
commit 2eca4cd376
@@ -3,12 +3,11 @@ from typing import Optional
 import torch
 from torch.utils._pytree import tree_map
 
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
-from colossalai.tensor.sharding_spec import ShardingSpec
+from .layout import Layout
+from .layout_converter import LayoutConverter, to_global
+from .sharding_spec import ShardingSpec
 
-shape_consistency_manager = ShapeConsistencyManager()
+layout_converter = LayoutConverter()
 
 
 class DTensor(torch.Tensor):
@@ -17,8 +16,6 @@ class DTensor(torch.Tensor):
         self.local_tensor = local_tensor
         self.data_type = local_tensor.dtype
         self.entire_shape = local_tensor.shape
-        if dist_layout.entire_shape is None:
-            dist_layout.entire_shape = self.entire_shape
         self.dist_layout = dist_layout
         self._apply_layout()
 
@@ -36,20 +33,19 @@ class DTensor(torch.Tensor):
         '''
         Convert the layout of the tensor from source_spec to target_spec.
         '''
-        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        target_spec = convert_layout_to_sharding_spec(target_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
-            self.local_tensor, source_spec, target_spec)
+        self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
         self.dist_layout = target_layout
 
     def _apply_layout(self):
         '''
         Apply the layout to the local tensor during initializing process.
         '''
-        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
-        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
-            self.local_tensor, source_spec, target_spec)
+        source_spec = construct_default_sharding_spec(self.local_tensor)
+        source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
+                               device_type=self.dist_layout.device_type,
+                               sharding_spec=source_spec,
+                               entire_shape=self.entire_shape)
+        self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -108,7 +104,7 @@ class DTensor(torch.Tensor):
         will not change the layout of the DTensor. This function is mainly used for debugging or
         check the correctness of the distributed tensor.
         '''
-        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))
+        return to_global(self.local_tensor, self.dist_layout)
 
 
 def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
@@ -139,20 +135,8 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
     return module
 
 
-def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
-    '''
-    Convert the layout from Layout class to ShardingSpec class.
-    '''
-    return ShardingSpec(device_mesh=layout.device_mesh,
-                        entire_shape=layout.entire_shape,
-                        dim_partition_dict=layout.sharding_spec.dim_partition_dict)
-
-
-def construct_default_sharding_spec(
-    tensor: torch.Tensor,
-    device_mesh: DeviceMesh,
-) -> ShardingSpec:
+def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
     '''
     Construct the default sharding specification for the tensor.
     '''
-    return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
+    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
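The hunks above are from the DTensor module itself (colossalai/tensor/d_tensor/d_tensor.py, judging from the imports in the test file further down). The commit replaces the module-level ShapeConsistencyManager and the convert_layout_to_sharding_spec helper with a single LayoutConverter, and expects a Layout to arrive fully populated, so the entire_shape back-fill in __init__ disappears. Below is a rough usage sketch of the refactored construction path, assembled only from signatures visible in this commit; the shapes, the .cuda() placement and the assumption of a four-process launch are illustrative, not part of the change.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# Assumed: this runs inside a 4-process colossalai launch so the 2x2 mesh
# below can initialize its process groups.
original_tensor = torch.rand(4, 8).cuda()

# Four devices arranged as a 2x2 logical mesh, as in the test further down.
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

# New-style ShardingSpec: only the tensor rank and the dim -> mesh-axis mapping.
# Here tensor dim 0 is sharded along mesh axis 0.
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})

# Layout now bundles mesh, device type, sharding spec and the global shape,
# so DTensor no longer has to back-fill entire_shape itself.
layout = Layout(device_mesh=device_mesh,
                device_type=torch.device('cuda'),
                sharding_spec=target_sharding_spec,
                entire_shape=original_tensor.shape)

# distribute_tensor wraps the tensor in a DTensor; _apply_layout then uses the
# module-level LayoutConverter to shard the local data to match `layout`.
d_tensor = distribute_tensor(original_tensor, layout)

The next hunk applies the matching rename inside the layout converter module that the new relative import points at.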
@@ -22,21 +22,21 @@ __all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_o
 @dataclass
 class LayoutConverterOptions:
     """
-    LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency.
+    LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
     """
     # TODO: layout converter option is not implemented yet
     pass
 
 
 def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    shape_consistency_manager = LayoutConverter()
+    layout_converter = LayoutConverter()
     global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
     global_layout = Layout(device_mesh=layout.device_mesh,
                            device_type=layout.device_type,
                            sharding_spec=global_sharding_spec,
                            entire_shape=layout.entire_shape)
     with torch.no_grad():
-        global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout)
+        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
     return global_tensor
 
 
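In to_global, the converter instance is renamed accordingly: the helper builds a fully replicated target Layout (an empty dim_partition_dict over the same mesh) and routes it through LayoutConverter.apply, which is also what the DTensor debugging helper in the first file now delegates to. A minimal sketch, continuing from the previous snippet (d_tensor and layout are assumed to exist on every rank):

from colossalai.tensor.d_tensor.layout_converter import to_global

# Gather the local shard back to the global shape; per the docstring above,
# this is meant for debugging/correctness checks, not the training fast path.
full_tensor = to_global(d_tensor.local_tensor, layout)
assert full_tensor.shape == layout.entire_shape

The remaining hunks update the DTensor unit test to the new import path and constructor signature.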
@@ -4,12 +4,11 @@ import torch
 import torch.multiprocessing as mp
 
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
 from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.utils import free_port
 
 
@@ -34,9 +33,7 @@ def check_dtensor(rank, world_size, port):
     compare_output = test_model(original_tensor)
 
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
-    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                        entire_shape=original_tensor.shape,
-                                        dim_partition_dict={0: [0]})
+    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
     layout = Layout(device_mesh=device_mesh,
                     device_type=torch.device('cuda'),
                     sharding_spec=target_sharding_spec,
@@ -62,9 +59,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
 
-    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                     entire_shape=original_tensor.shape,
-                                     dim_partition_dict={0: [0, 1]})
+    new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
     new_layout = Layout(device_mesh=device_mesh,
                         device_type=torch.device('cuda'),
                         sharding_spec=new_sharding_spec,
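After the assertions shown here, the test builds this second spec/layout pair to re-shard the same tensor from {0: [0]} to {0: [0, 1]}. That conversion ultimately runs through LayoutConverter.apply(tensor, source_layout, target_layout), the same call the refactored DTensor uses internally. The following reduced sketch continues from the first snippet and keeps the test's argument names; the DTensor method that triggers the conversion on the wrapped tensor is not shown in this diff.

from colossalai.tensor.d_tensor.layout_converter import LayoutConverter

layout_converter = LayoutConverter()

# Shard tensor dim 0 across both mesh axes instead of only axis 0.
new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
new_layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=new_sharding_spec,
                    entire_shape=original_tensor.shape)

# apply(tensor, source_layout, target_layout) performs the collective
# communication needed to move the local shard between the two layouts.
local_shard = layout_converter.apply(d_tensor.local_tensor, layout, new_layout)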