mirror of https://github.com/hpcaitech/ColossalAI
[DTensor] implement layout converter (#3055)
* [DTensor] refactor LayoutConverter for DTensor * polish code * polish docstringpull/3085/head
parent
89aa7926ac
commit
8e4e8601b7
|
@ -0,0 +1,556 @@
|
|||
import math
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.tensor.d_tensor.comm_spec import *
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.tensor.sharding_spec import ShardingSpecException
|
||||
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
|
||||
|
||||
from .sharding_spec import ShardingSpec
|
||||
from .utils import get_comm_cost
|
||||
|
||||
__all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_options']
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayoutConverterOptions:
|
||||
"""
|
||||
LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency.
|
||||
"""
|
||||
# TODO: layout converter option is not implemented yet
|
||||
pass
|
||||
|
||||
|
||||
def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
|
||||
shape_consistency_manager = LayoutConverter()
|
||||
global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
|
||||
global_layout = Layout(device_mesh=layout.device_mesh,
|
||||
device_type=layout.device_type,
|
||||
sharding_spec=global_sharding_spec,
|
||||
entire_shape=layout.entire_shape)
|
||||
with torch.no_grad():
|
||||
global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout)
|
||||
return global_tensor
|
||||
|
||||
|
||||
def set_layout_converting_options(options: LayoutConverterOptions):
|
||||
"""
|
||||
Configure the shape consistency manager via function call.
|
||||
"""
|
||||
manager = LayoutConverter()
|
||||
manager.options = options
|
||||
|
||||
|
||||
class LayoutConverter(metaclass=SingletonMeta):
|
||||
|
||||
def __init__(self):
|
||||
self._options = None
|
||||
self._forward_only = False
|
||||
self.cached_solution = {}
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
return self._options
|
||||
|
||||
@options.setter
|
||||
def options(self, options_: LayoutConverterOptions):
|
||||
assert isinstance(options_, LayoutConverterOptions)
|
||||
self._options = options_
|
||||
|
||||
@property
|
||||
def forward_only(self):
|
||||
return self._forward_only
|
||||
|
||||
@forward_only.setter
|
||||
def forward_only(self, value):
|
||||
assert isinstance(value, bool)
|
||||
self._forward_only = value
|
||||
|
||||
def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
|
||||
'''
|
||||
Get all valid layouts from source_layout with single all-gather operation.
|
||||
For the all-gather operation, we just care about the S dimension.
|
||||
|
||||
Argument:
|
||||
source_layout: the layout to be transformed.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single all-gather operation.
|
||||
|
||||
Example:
|
||||
layout_converter = LayoutConverter()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
# [S0,S1,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
rst_dict = layout_converter.all_gather_transform_layouts(layout)
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}')
|
||||
|
||||
Output:
|
||||
[R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0)
|
||||
[S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
|
||||
source_spec = source_layout.sharding_spec
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
for target_pair in source_spec.dim_partition_dict.items():
|
||||
shard_list = all_gather_simulator(target_pair)
|
||||
index = target_pair[0]
|
||||
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
|
||||
|
||||
# We won't add empty list into dim_partition_dict
|
||||
# The key will be popped if the related shard_list is empty
|
||||
if shard_list:
|
||||
new_dim_partition_dict[index] = shard_list
|
||||
else:
|
||||
new_dim_partition_dict.pop(index)
|
||||
|
||||
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
|
||||
gather_dim = index
|
||||
logical_process_axis = target_pair[1][-1]
|
||||
comm_spec = CommSpec(
|
||||
comm_pattern,
|
||||
process_groups_dict=process_groups_dict,
|
||||
gather_dim=gather_dim,
|
||||
# shard_dim will be used during backward
|
||||
shard_dim=gather_dim,
|
||||
logical_process_axis=logical_process_axis)
|
||||
|
||||
# generate new sharding spec
|
||||
try:
|
||||
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except ShardingSpecException:
|
||||
pass
|
||||
return valid_spec_dict
|
||||
|
||||
def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
|
||||
'''
|
||||
Get all valid layouts from source_layout with single all-to-all operation.
|
||||
For the all-to-all operation, we just care about the pairs containing S dimension.
|
||||
|
||||
Argument:
|
||||
source_layout(Layout): the layout to be transformed.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single all-to-all operation.
|
||||
|
||||
Example:
|
||||
layout_converter = LayoutConverter()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
# [S0,S1,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
rst_dict = layout_converter.all_to_all_transform_layout(layout)
|
||||
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}')
|
||||
|
||||
Output:
|
||||
[S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1)
|
||||
[R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0)
|
||||
[S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1)
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
source_spec = source_layout.sharding_spec
|
||||
tensor_dims = source_spec.dims
|
||||
for f_index in range(tensor_dims - 1):
|
||||
for b_index in range(f_index + 1, tensor_dims):
|
||||
# skip (R, R) cases
|
||||
if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
|
||||
continue
|
||||
else:
|
||||
if f_index in source_spec.dim_partition_dict:
|
||||
# skip (S01, R) -> (R, S01) is NOT allowed
|
||||
if len(source_spec.dim_partition_dict[f_index]) >= 2:
|
||||
continue
|
||||
f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
|
||||
else:
|
||||
f_target_pair = (f_index, [])
|
||||
if b_index in source_spec.dim_partition_dict:
|
||||
# skip (R, S01) -> (S01, R) is NOT allowed
|
||||
if len(source_spec.dim_partition_dict[b_index]) >= 2:
|
||||
continue
|
||||
b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
|
||||
else:
|
||||
b_target_pair = (b_index, [])
|
||||
|
||||
# skip (S1, S0) -> S10
|
||||
if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]:
|
||||
continue
|
||||
f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair)
|
||||
f_index = f_target_pair[0]
|
||||
b_index = b_target_pair[0]
|
||||
|
||||
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
|
||||
if len(f_shard_list) < len(f_target_pair[1]):
|
||||
gather_dim = f_index
|
||||
shard_dim = b_index
|
||||
logical_process_axis = f_target_pair[1][-1]
|
||||
else:
|
||||
gather_dim = b_index
|
||||
shard_dim = f_index
|
||||
logical_process_axis = b_target_pair[1][-1]
|
||||
comm_spec = CommSpec(comm_pattern,
|
||||
process_groups_dict,
|
||||
gather_dim=gather_dim,
|
||||
shard_dim=shard_dim,
|
||||
logical_process_axis=logical_process_axis)
|
||||
|
||||
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
|
||||
|
||||
# We won't add empty list into dim_partition_dict
|
||||
# The key will be popped if the related shard_list is empty
|
||||
if f_shard_list:
|
||||
new_dim_partition_dict[f_index] = f_shard_list
|
||||
else:
|
||||
new_dim_partition_dict.pop(f_index)
|
||||
if b_shard_list:
|
||||
new_dim_partition_dict[b_index] = b_shard_list
|
||||
else:
|
||||
new_dim_partition_dict.pop(b_index)
|
||||
|
||||
# generate new sharding spec
|
||||
try:
|
||||
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except ShardingSpecException:
|
||||
pass
|
||||
|
||||
return valid_spec_dict
|
||||
|
||||
def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
|
||||
'''
|
||||
Get all valid layouts from source_layout with single shard operation.
|
||||
For the sharding operation, we just care about legal sharding dimensions.
|
||||
|
||||
Argument:
|
||||
source_layout(Layout): the layout to be transformed.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single shard operation.
|
||||
|
||||
Example:
|
||||
layout_converter = LayoutConverter()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
|
||||
dim_partition_dict = {0: [0]}
|
||||
|
||||
# [S0,R,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
rst_dict = layout_converter.shard_transform_layout(layout)
|
||||
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}')
|
||||
|
||||
Output:
|
||||
[S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1)
|
||||
[S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
|
||||
[S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1)
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
|
||||
source_spec = source_layout.sharding_spec
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
|
||||
# legal sharding dims means the mesh_id is still available to use.
|
||||
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
|
||||
for dim, shard_list in source_spec.dim_partition_dict.items():
|
||||
for element in shard_list:
|
||||
legal_sharding_dims.remove(element)
|
||||
|
||||
if len(legal_sharding_dims) == 0:
|
||||
return valid_spec_dict
|
||||
|
||||
tensor_dims = source_spec.dims
|
||||
|
||||
for index in range(tensor_dims):
|
||||
if index not in source_spec.dim_partition_dict:
|
||||
shard_list_list = shard_simulator((index, []), legal_sharding_dims)
|
||||
else:
|
||||
shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims)
|
||||
if not shard_list_list:
|
||||
continue
|
||||
for shard_list in shard_list_list:
|
||||
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
|
||||
new_dim_partition_dict[index] = shard_list
|
||||
|
||||
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
|
||||
shard_dim = index
|
||||
logical_process_axis = shard_list[-1]
|
||||
comm_spec = CommSpec(comm_pattern,
|
||||
process_groups_dict,
|
||||
gather_dim=shard_dim,
|
||||
shard_dim=shard_dim,
|
||||
logical_process_axis=logical_process_axis)
|
||||
|
||||
# generate new sharding spec
|
||||
try:
|
||||
new_sharding_spec = ShardingSpec(dim_size=source_spec.dims,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except ShardingSpecException:
|
||||
pass
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
|
||||
'''
|
||||
Get all valid layouts from source_layout with one step transform.
|
||||
|
||||
Note:
|
||||
all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
|
||||
and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive,
|
||||
we could safely put them together.
|
||||
|
||||
Argument:
|
||||
source_layout(Layout): the layout to be transformer.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with one step transform.
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
valid_spec_dict.update(self.all_gather_transform_layouts(source_layout))
|
||||
valid_spec_dict.update(self.all_to_all_transform_layout(source_layout))
|
||||
valid_spec_dict.update(self.shard_transform_layout(source_layout))
|
||||
return valid_spec_dict
|
||||
|
||||
def layout_converting(self, source_layout: Layout,
|
||||
target_layout: Layout) -> Tuple[List[Layout], List[CommSpec], float]:
|
||||
'''
|
||||
This method will find a path to transform source_layout to target_layout with
|
||||
a greedy algorithm.
|
||||
The basic idea is:
|
||||
Step1:
|
||||
Generate all one-step transform sequences from source_layout.
|
||||
Step2:
|
||||
Pick the 'best' layout following the heuristic function.
|
||||
Step3:
|
||||
Repeat above steps until the source layout transform to target layout.
|
||||
|
||||
Additionally, to avoid repeating the path search in runtime, we cached all solved path
|
||||
in auto parallel strategy building time, which could handle most of cases in runtime.
|
||||
|
||||
Args:
|
||||
source_layout(Layout): the layout to be transformed.
|
||||
target_layout(Layout): the layout to be achieved after a serious of transforms.
|
||||
|
||||
Return:
|
||||
transform_path(List[Layout]): The transform path from source_layout to target_layout,
|
||||
it contains the source_layout and target_layout.
|
||||
comm_action_sequence(List[CommSpec]): Keep the communication operations to complete the layout converting in order.
|
||||
|
||||
Example:
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
|
||||
dim_partition_source = {1: [0, 1]}
|
||||
dim_partition_target = {0: [0, 1]}
|
||||
|
||||
# [R,S01,R]
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
# [S01,R,R]
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
|
||||
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
|
||||
print(transform_path_str)
|
||||
|
||||
output:
|
||||
[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]
|
||||
'''
|
||||
source_spec = source_layout.sharding_spec
|
||||
target_spec = target_layout.sharding_spec
|
||||
MAX_TRANSFORM_STEPS = 20
|
||||
total_steps = 0
|
||||
transform_path = []
|
||||
comm_action_sequence = []
|
||||
spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
|
||||
|
||||
if spec_pairs in self.cached_solution:
|
||||
return self.cached_solution[spec_pairs]
|
||||
|
||||
# We do nothing if the sharding spec is all the same.
|
||||
if source_spec.spec_diff(target_spec) == 0:
|
||||
self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence)
|
||||
return (
|
||||
transform_path,
|
||||
comm_action_sequence,
|
||||
)
|
||||
|
||||
temp_sharding_layout = source_layout
|
||||
|
||||
transform_path.append(temp_sharding_layout)
|
||||
# To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
|
||||
while total_steps <= MAX_TRANSFORM_STEPS:
|
||||
valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_layout)
|
||||
best_difference_score = math.inf
|
||||
|
||||
for layout, comm_spec in valid_transform_spec_dict.items():
|
||||
sharding_spec = layout.sharding_spec
|
||||
spec_difference = sharding_spec.spec_diff(target_spec)
|
||||
|
||||
if spec_difference == 0:
|
||||
transform_path.append(layout)
|
||||
comm_action_sequence.append(comm_spec)
|
||||
self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence)
|
||||
return (transform_path, comm_action_sequence)
|
||||
|
||||
if spec_difference < best_difference_score:
|
||||
temp_sharding_layout = layout
|
||||
temp_comm_spec = comm_spec
|
||||
best_difference_score = spec_difference
|
||||
|
||||
transform_path.append(temp_sharding_layout)
|
||||
comm_action_sequence.append(temp_comm_spec)
|
||||
|
||||
total_steps += 1
|
||||
|
||||
raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
|
||||
|
||||
def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]:
|
||||
'''
|
||||
Get the total communication cost of the layout converting process.
|
||||
'''
|
||||
transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout)
|
||||
total_cost = {'forward': 0.0, 'backward': 0.0, 'total': 0.0}
|
||||
for layout, comm_spec in zip(transform_path, comm_action_sequence):
|
||||
cost_dict = get_comm_cost(layout, comm_spec, self.forward_only)
|
||||
for key in total_cost:
|
||||
total_cost[key] += cost_dict[key]
|
||||
return total_cost
|
||||
|
||||
def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor:
|
||||
'''
|
||||
Apply target_layout to tensor with source layout, the transform path is generated by the
|
||||
layout_converting method.
|
||||
|
||||
Argument:
|
||||
tensor (torch.Tensor): The tensor to be redistributed.
|
||||
source_layout(Layout): The source layout of the tensor.
|
||||
target_layout (Layout): The tensor will be redistributed to the target_layout.
|
||||
|
||||
Example:
|
||||
layout_converter = LayoutConverter()
|
||||
dim_partition_source = {0: [0]}
|
||||
dim_partition_target = {1: [0]}
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
|
||||
# [S0,R,R]
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
# [R,S0,R]
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
if rank in (0, 1):
|
||||
sharded_tensor_0 = torch.zeros(2, 1)
|
||||
sharded_tensor_1 = torch.ones(2, 1)
|
||||
# tensor([[0., 1.],
|
||||
# [0., 1.]])
|
||||
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
|
||||
if rank in (2, 3):
|
||||
sharded_tensor_0 = torch.ones(2, 1) * 2
|
||||
sharded_tensor_1 = torch.ones(2, 1) * 3
|
||||
# tensor([[2., 3.],
|
||||
# [2., 3.]])
|
||||
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
|
||||
|
||||
# converted_tensor: [R, S0, R]
|
||||
converted_tensor = layout_converter.apply(tensor_to_comm, source_layout, target_layout)
|
||||
print(converted_tensor)
|
||||
|
||||
Output in rank0 and rank1:
|
||||
tensor([[0.],
|
||||
[0.],
|
||||
[2.],
|
||||
[2.]])
|
||||
|
||||
Output in rank2 and rank3:
|
||||
tensor([[1.],
|
||||
[1.],
|
||||
[3.],
|
||||
[3.]])
|
||||
'''
|
||||
_, comm_action_sequence = self.layout_converting(source_layout, target_layout)
|
||||
for comm_spec in comm_action_sequence:
|
||||
tensor = comm_spec.covert_spec_to_action(tensor)
|
||||
return tensor
|
|
@ -0,0 +1,66 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from typing import Dict
|
||||
|
||||
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
|
||||
|
||||
def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]:
|
||||
'''
|
||||
This method is used to compute the communication cost for a given layout and comm_spec.
|
||||
|
||||
For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
|
||||
compute the communication cost. For shard operation, it is an on-chip operation, so the communication cost is a tiny cost.
|
||||
|
||||
Args:
|
||||
layout: the layout of the tensor.
|
||||
comm_spec: the comm_spec to instruct the communication operation.
|
||||
forward_only: if it is True, we will just count the forward communication cost.
|
||||
If it is False, we will count both forward and backward communication cost.
|
||||
'''
|
||||
comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1)
|
||||
device_mesh = layout.device_mesh
|
||||
comm_pattern = comm_spec.comm_pattern
|
||||
logical_process_axis = comm_spec.logical_process_axis
|
||||
cost_dict = {}
|
||||
|
||||
if comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
|
||||
# the comm size for all gather is the size of the gathered tensor
|
||||
gather_dim = comm_spec.gather_dim
|
||||
all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1]
|
||||
all_gather_size = device_mesh.mesh_shape[all_gather_axis]
|
||||
comm_size_for_all_gather = comm_size * all_gather_size
|
||||
forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis)
|
||||
# give a tiny cost to shard
|
||||
backward_communication_cost = 100
|
||||
|
||||
if comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
|
||||
forward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis)
|
||||
# grad should have same shape as input tensor
|
||||
# all to all operation has same logical process axis as forward.
|
||||
backward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis)
|
||||
|
||||
if comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
|
||||
forward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis)
|
||||
backward_communication_cost = 0
|
||||
|
||||
if comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
|
||||
forward_communication_cost = 0
|
||||
backward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis)
|
||||
|
||||
if comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
|
||||
# give a tiny cost to shard
|
||||
forward_communication_cost = 100
|
||||
backward_communication_cost = device_mesh.all_gather_cost(comm_size, logical_process_axis)
|
||||
|
||||
if forward_only:
|
||||
cost_dict["forward"] = forward_communication_cost
|
||||
cost_dict["backward"] = 0
|
||||
cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"]
|
||||
else:
|
||||
cost_dict["forward"] = forward_communication_cost
|
||||
cost_dict["backward"] = backward_communication_cost
|
||||
cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"]
|
||||
|
||||
return cost_dict
|
|
@ -0,0 +1,206 @@
|
|||
import math
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
|
||||
from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
|
||||
entire_shape = torch.Size((64, 32, 16))
|
||||
layout_converter = LayoutConverter()
|
||||
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
|
||||
mesh_shape = (2, 2)
|
||||
|
||||
|
||||
def check_one_step_transform(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
# [[0, 1],
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (2, 2)
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
rst_dict = layout_converter.all_gather_transform_layouts(layout)
|
||||
|
||||
assert '[R, S1, R]' in [
|
||||
str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
|
||||
]
|
||||
assert '[S0, R, R]' in [
|
||||
str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
|
||||
]
|
||||
|
||||
dim_partition_dict_all2all = {0: [0], 1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
|
||||
layout_all2all = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_all2all,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
|
||||
|
||||
assert '[S01, R, R]' in [
|
||||
str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
|
||||
]
|
||||
assert '[R, S1, S0]' in [
|
||||
str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
|
||||
]
|
||||
assert '[S0, R, S1]' in [
|
||||
str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
|
||||
]
|
||||
|
||||
dim_partition_shard = {0: [0]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
|
||||
shard_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_shard,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
|
||||
|
||||
assert '[S01, R, R]' in [
|
||||
str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
|
||||
]
|
||||
assert '[S0, S1, R]' in [
|
||||
str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
|
||||
]
|
||||
assert '[S0, R, S1]' in [
|
||||
str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
|
||||
]
|
||||
|
||||
|
||||
def check_layout_converting(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
dim_partition_source = {1: [0, 1]}
|
||||
dim_partition_target = {0: [0, 1]}
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: R,S01,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
|
||||
|
||||
# check transform path
|
||||
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
|
||||
assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'
|
||||
|
||||
# check comm action sequence
|
||||
# all-gather(S01) -> S0
|
||||
assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
|
||||
assert comm_action_sequence[0].gather_dim == 1
|
||||
assert comm_action_sequence[0].logical_process_axis == 1
|
||||
|
||||
# all-to-all(R, S0) -> [S0, R]
|
||||
assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
|
||||
assert comm_action_sequence[1].gather_dim == 1
|
||||
assert comm_action_sequence[1].shard_dim == 0
|
||||
assert comm_action_sequence[1].logical_process_axis == 0
|
||||
|
||||
# shard(S0) -> [S01]
|
||||
assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
|
||||
assert comm_action_sequence[2].shard_dim == 0
|
||||
assert comm_action_sequence[2].logical_process_axis == 1
|
||||
|
||||
# checkout chached_spec_pairs_transform_path
|
||||
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
|
||||
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
|
||||
|
||||
comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout)
|
||||
|
||||
assert comm_cost['forward'] == comm_cost['backward']
|
||||
assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward'])
|
||||
|
||||
|
||||
def check_layout_converting_apply(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
dim_partition_source = {1: [0, 1]}
|
||||
dim_partition_target = {0: [0, 1]}
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: R,S01,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
|
||||
original_tensor = torch.rand(entire_shape).cuda()
|
||||
|
||||
# tensor_to_apply: [R, S01, R]
|
||||
tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
|
||||
|
||||
# tensor_to_check: [S01, R, R]
|
||||
tensor_to_check = original_tensor.narrow(0, rank * 16, 16)
|
||||
|
||||
converted_tensor = layout_converter.apply(tensor_to_apply, source_layout, target_layout)
|
||||
assert converted_tensor.equal(tensor_to_check)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_layout_converter():
|
||||
world_size = 4
|
||||
run_func = partial(check_one_step_transform, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
run_func = partial(check_layout_converting, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
run_func = partial(check_layout_converting_apply, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_layout_converter()
|
Loading…
Reference in New Issue