[tensor] shape consistency generate transform path and communication cost (#1435)

* [tensor] shape consistency output transform path and communication cost

* polish code
pull/1449/head
YuliangLiu0306 2022-08-12 14:02:32 +08:00 committed by GitHub
parent 5774fe0270
commit 0f3042363c
4 changed files with 533 additions and 134 deletions


@@ -1,7 +1,11 @@
import torch
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from enum import Enum
from copy import deepcopy
import math
from functools import reduce
import operator
class CollectiveCommPattern(Enum):
@@ -10,96 +14,71 @@ class CollectiveCommPattern(Enum):
SHARD = 'shard'
class CommSpec:
'''
Communication spec is used to record the communication action. It has two main functions:
1. Compute the communication cost which will be used in auto parallel solver.
2. Convert the communication spec to real action which will be used in runtime.
It contains comm_pattern to determine the communication method, sharding_spec to determine the communication size,
gather_dim and shard_dim to determine the buffer shape, and logical_process_axis to determine the device mesh
dimension along which the communication is performed.
Argument:
comm_pattern(CollectiveCommPattern): describes the communication method used in this spec.
sharding_spec(ShardingSpec): the sharding spec of the tensor which will join the communication action.
gather_dim(int, optional): the dimension of the tensor along which data will be gathered.
shard_dim(int, optional): the dimension of the tensor along which data will be sharded.
logical_process_axis(int, optional): the device mesh dimension used to implement the communication action.
'''
def __init__(self, comm_pattern, sharding_spec, gather_dim=None, shard_dim=None, logical_process_axis=None):
self.comm_pattern = comm_pattern
self.sharding_spec = sharding_spec
self.gather_dim = gather_dim
self.shard_dim = shard_dim
self.logical_process_axis = logical_process_axis
def __repr__(self):
res_list = ["CommSpec:("]
if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
res_list.append(f"comm_pattern:allgather, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
res_list.append(f"comm_pattern:all2all, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis: {self.logical_process_axis})")
else:
res_list.append(f"comm_pattern:shard, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
return ''.join(res_list)
def get_comm_cost(self):
'''
For all_gather and all2all operation, the formula provided in DeviceMesh with alpha-beta model is used to
compute the communication cost.
For shard operation, it is an on-chip operation, so the communication cost is zero.
'''
comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
return self.sharding_spec.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.ALLTOALL:
return self.sharding_spec.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
return 0
def covert_spec_to_action(self):
pass
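A minimal usage sketch of CommSpec and get_comm_cost, assuming colossalai at this commit is installed; it reuses the (4, 4) device mesh and (64, 32, 16) tensor shape from the test file at the bottom of this diff (the imports are only needed when running the sketch as a standalone script):
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern

physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
device_mesh = DeviceMesh(physical_mesh_id, (4, 4))
entire_shape = torch.Size((64, 32, 16))
# tensor sharded as S0,R,R: dim 0 is split along mesh axis 0
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict={0: [0]})
# record an all-gather of tensor dim 0 along mesh axis 0 and query its alpha-beta cost;
# internally the per-device shape (16, 32, 16) is reduced to a message size and passed
# to device_mesh.all_gather_cost(size, axis); SHARD always costs 0.
allgather_spec = CommSpec(CollectiveCommPattern.ALLGATHER,
                          sharding_spec=sharding_spec,
                          gather_dim=0,
                          logical_process_axis=0)
print(allgather_spec.get_comm_cost())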
class ShapeConsistencyManager:
def __init__(self, consistency_option=None):
self.consistency_option = consistency_option
self.total_communication_cost = 0
self.total_transform_steps = 0
self.cached_spec_pairs = {}
def _all_gather_simulator(self, target_pair):
'''
Simulate the all-gather operation: analyze the communication cost
and the influence on the DimSpec.
We don't allow non-contiguous layouts, so all-gather(S012) -> S02 is NOT allowed.
Therefore, the all-gather operation just removes the last element of the shard list,
e.g.:
all-gather(S01) -> S0
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
and the second element describes which logical mesh axes shard that dimension.
'''
_, shard_list = target_pair
new_shard_list = shard_list[:-1]
# TODO: compute comm cost
comm_cost = 0
return new_shard_list, comm_cost
def _all_to_all_simulator(self, f_target_pair, b_target_pair):
'''
Simulate the all-to-all operation: analyze the communication cost
and the influence on the DimSpecs.
We ban all representations whose shard_list is in decreasing order,
such as S10, so all-to-all(S0, S1) -> [R, S10] is NOT allowed.
Therefore, if the behind shard_list is not empty, we extend the front shard_list with it,
e.g.:
all-to-all(S0, S1) -> [S01, R]
all-to-all(R, S1) -> [S1, R]
Otherwise, we move the front shard_list to the behind position,
e.g.:
all-to-all(S0, R) -> [R, S0]
Argument:
f_target_pair(Tuple[int, List[int]]): The first element is the front tensor dimension,
and the second element describes which logical mesh axes shard that dimension.
b_target_pair(Tuple[int, List[int]]): The first element is the behind tensor dimension,
and the second element describes which logical mesh axes shard that dimension.
'''
_, f_shard_list = f_target_pair
_, b_shard_list = b_target_pair
if not len(b_shard_list):
b_shard_list.extend(f_shard_list)
f_shard_list = []
else:
f_shard_list.extend(b_shard_list)
b_shard_list = []
# TODO: compute comm cost
comm_cost = 0
return f_shard_list, b_shard_list, comm_cost
def _shard_simulator(self, target_pair, legal_sharding_dims):
'''
Simulate the shard operation: analyze the communication cost (always ZERO)
and the influence on the DimSpec.
We don't allow non-contiguous layouts, so shard(S0) -> S02 is NOT allowed.
In addition, we ban all representations whose shard_list is in decreasing order,
such as S10, so shard(S1) -> S10 is NOT allowed.
Therefore, for an R dimension, we can simply append any legal sharding dim to it,
e.g.:
shard(R) -> S0
For an S dimension, we need to make sure the shard_list after sharding is still in increasing order,
e.g.:
shard(S0) -> S01
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
and the second element describes which logical mesh axes shard that dimension.
'''
_, shard_list = target_pair
shard_list_list = []
for dim in legal_sharding_dims:
if len(shard_list) != 0 and dim <= shard_list[-1]:
continue
new_shard_list = shard_list + [dim]
shard_list_list.append(new_shard_list)
comm_cost = 0
return shard_list_list, comm_cost
self.cached_spec_pairs_transform_path = {}
def get_all_all_gather_spec(self, source_spec, orig_cost):
'''
@@ -132,15 +111,35 @@ class ShapeConsistencyManager:
device_mesh_shape: (4, 4): 0}
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALLGATHER
for target_pair in source_spec.dim_partition_dict.items():
shard_list, cost = self._all_gather_simulator(target_pair)
shard_list = all_gather_simulator(target_pair)
index = target_pair[0]
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
new_dim_partition_dict[index] = shard_list
# We won't add an empty list into dim_partition_dict
# The key will be popped if the related shard_list is empty
if shard_list:
new_dim_partition_dict[index] = shard_list
else:
new_dim_partition_dict.pop(index)
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
gather_dim = index
logical_process_axis = target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axis)
# compute the communication cost with CommSpec
cost = comm_spec.get_comm_cost()
# generate new sharding spec
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
valid_spec_dict[new_sharding_spec] = orig_cost + cost
valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
return valid_spec_dict
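For orientation, a sketch (based on the one-step test at the bottom of this diff) of the mapping this method returns for a [S0, S1, R] source on the (4, 4) mesh; the numeric costs come from the device mesh's alpha-beta model and are omitted here:
# rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
# Each key is a reachable ShardingSpec, each value a (CommSpec, accumulated cost) pair:
#   [R, S1, R]: (CommSpec(allgather, gather_dim=0, logical_process_axis=0), cost)
#   [S0, R, R]: (CommSpec(allgather, gather_dim=1, logical_process_axis=1), cost)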
def get_all_all_to_all_spec(self, source_spec, orig_cost):
@@ -176,6 +175,7 @@ class ShapeConsistencyManager:
device_mesh_shape: (4, 4): 0}
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALLTOALL
tensor_dims = len(source_spec.entire_shape)
for f_index in range(tensor_dims - 1):
for b_index in range(f_index + 1, tensor_dims):
@@ -184,24 +184,62 @@ class ShapeConsistencyManager:
continue
else:
if f_index in source_spec.dim_partition_dict:
# skip (S01, R) -> (R, S01) is NOT allowed
if len(source_spec.dim_partition_dict[f_index]) >= 2:
continue
f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
else:
f_target_pair = (f_index, [])
if b_index in source_spec.dim_partition_dict:
# skip (R, S01) -> (S01, R) is NOT allowed
if len(source_spec.dim_partition_dict[b_index]) >= 2:
continue
b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
else:
b_target_pair = (b_index, [])
f_shard_list, b_shard_list, cost = self._all_to_all_simulator(f_target_pair, b_target_pair)
# skip (S1, S0) -> S10
if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]:
continue
f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair)
f_index = f_target_pair[0]
b_index = b_target_pair[0]
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
if len(f_shard_list) < len(f_target_pair[1]):
gather_dim = f_index
shard_dim = b_index
logical_process_axis = f_target_pair[1][-1]
else:
gather_dim = b_index
shard_dim = f_index
logical_process_axis = b_target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
# compute the communication cost with CommSpec
cost = comm_spec.get_comm_cost()
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
new_dim_partition_dict[f_index] = f_shard_list
new_dim_partition_dict[b_index] = b_shard_list
# We won't add an empty list into dim_partition_dict
# The key will be popped if the related shard_list is empty
if f_shard_list:
new_dim_partition_dict[f_index] = f_shard_list
else:
new_dim_partition_dict.pop(f_index)
if b_shard_list:
new_dim_partition_dict[b_index] = b_shard_list
else:
new_dim_partition_dict.pop(b_index)
# generate new sharding spec
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
valid_spec_dict[new_sharding_spec] = orig_cost + cost
valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
return valid_spec_dict
def get_all_shard_spec(self, source_spec, orig_cost):
@@ -237,6 +275,9 @@ class ShapeConsistencyManager:
device_mesh_shape: (4, 4): 0}
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.SHARD
# legal sharding dims are the mesh axes that are still available to use.
legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
for dim, shard_list in source_spec.dim_partition_dict.items():
for element in shard_list:
@@ -247,19 +288,31 @@ class ShapeConsistencyManager:
tensor_dims = len(source_spec.entire_shape)
for index in range(tensor_dims):
if index not in source_spec.dim_partition_dict:
shard_list_list, cost = self._shard_simulator((index, []), legal_sharding_dims)
shard_list_list = shard_simulator((index, []), legal_sharding_dims)
else:
shard_list_list, cost = self._shard_simulator((index, source_spec.dim_partition_dict[index]),
legal_sharding_dims)
shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims)
if not shard_list_list:
continue
for shard_list in shard_list_list:
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
new_dim_partition_dict[index] = shard_list
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
shard_dim = index
logical_process_axis = shard_list[-1]
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
# compute the communication cost with CommSpec
cost = comm_spec.get_comm_cost()
# generate new sharding spec
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
valid_spec_dict[new_sharding_spec] = orig_cost + cost
valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
return valid_spec_dict
def get_all_one_step_transform_spec(self, source_spec, orig_cost):
@@ -296,25 +349,93 @@ class ShapeConsistencyManager:
Step3:
Repeat the above steps until the source spec is transformed into the target spec.
This function is NOT completed, due to the absence of a difference function.
While searching for the transform path, the communication cost is accumulated, and it
will eventually be used by the auto parallel solver.
Additionally, to avoid repeating the path search at runtime, we cache every solved path
at auto parallel strategy building time, which covers most cases encountered at runtime.
Argument:
source_spec(ShardingSpec): ShardingSpec of the source activation.
target_spec(ShardingSpec): ShardingSpec of the target activation.
Return:
transform_path(List[ShardingSpec]): The transform path from source_spec to target_spec;
it contains both source_spec and target_spec.
comm_action_sequence(List[CommSpec]): The ordered communication operations required to complete the shape consistency transform.
total_cost(float): The total cost of the shape consistency transform.
Example:
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
# DistSpec:
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target)
print(f'transform_path: {transform_path}')
print(f'comm_action_sequence: {comm_action_sequence}')
print(f'total_cost: {total_cost}')
output:
transform_path: [DistSpec:
shard_sequence: R,S01,R
device_mesh_shape: (4, 4), DistSpec:
shard_sequence: R,S0,R
device_mesh_shape: (4, 4), DistSpec:
shard_sequence: S0,R,R
device_mesh_shape: (4, 4), DistSpec:
shard_sequence: S01,R,R
device_mesh_shape: (4, 4)]
comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1),
CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0),
CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)]
total_cost: 12294.402000000002
'''
MAX_TRANSFORM_STEPS = 10
MAX_TRANSFORM_STEPS = 20
total_cost = 0
total_steps = 0
transform_path = []
comm_action_sequence = []
spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
self.cached_spec_pairs_transform_path[spec_pairs] = (None, None)
# We do nothing if the source and target sharding specs are identical.
if source_spec.sharding_sequence_difference(target_spec) == 0:
self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
return (transform_path, comm_action_sequence, total_cost)
temp_sharding_spec = deepcopy(source_spec)
transform_path.append(temp_sharding_spec)
# To avoid an infinite loop, the search stops after MAX_TRANSFORM_STEPS transforms
while total_steps <= MAX_TRANSFORM_STEPS:
valid_transform_spec_dict = get_all_one_step_transform_spec(temp_sharding_spec)
best_difference_score = 0
for sharding_spec, cost in valid_transform_spec_dict.items():
if no_difference(sharding_spec, target_spec):
valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost)
best_difference_score = math.inf
for sharding_spec, info_pairs in valid_transform_spec_dict.items():
comm_spec, cost = info_pairs
spec_difference = sharding_spec.sharding_sequence_difference(target_spec)
if spec_difference == 0:
total_cost += cost
transform_path.append(sharding_spec)
return (transform_path, total_cost)
if difference(sharding_spec, target_spec) > best_difference_score:
comm_action_sequence.append(comm_spec)
self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
return (transform_path, comm_action_sequence, total_cost)
if spec_difference < best_difference_score:
temp_sharding_spec = deepcopy(sharding_spec)
temp_cost = cost
temp_comm_spec = deepcopy(comm_spec)
best_difference_score = spec_difference
transform_path.append(temp_sharding_spec)
comm_action_sequence.append(temp_comm_spec)
total_cost += temp_cost
return (transform_path, total_cost)
total_steps += 1
raise RuntimeError(f"Could not find a valid transform path within {MAX_TRANSFORM_STEPS} steps.")


@@ -1,4 +1,15 @@
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from copy import deepcopy
from enum import Enum
from functools import reduce
import operator
ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = 'nan'
class _DimSpec:
@@ -15,6 +26,7 @@ class _DimSpec:
def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
self.shard_list = shard_list
self.build_difference_2d_dict()
def __eq__(self, other):
return str(self) == str(other)
@@ -27,11 +39,101 @@ class _DimSpec:
target += str(dim)
return target
def _convert_str_to_shard_list(self, str_spec):
'''
Convert str_spec into shard_list.
Argument:
str_spec(str): dim spec in str type.
'''
if str_spec == 'R':
return []
if str_spec == 'S0':
return [0]
if str_spec == 'S1':
return [1]
if str_spec == 'S01':
return [0, 1]
def build_difference_2d_dict(self):
'''
Build a difference mapping for the 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
'''
source_spec_list = ['R', 'S0', 'S1', 'S01']
target_spec_list = ['R', 'S0', 'S1', 'S01']
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
legal_sharding_dims = []
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
source_shard_list = self._convert_str_to_shard_list(source_spec)
target_shard_list = self._convert_str_to_shard_list(target_spec)
# source same as target
if source_shard_list == target_shard_list:
difference = 0
# all_gather(source) -> target
elif len(source_shard_list
) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list:
difference = ALLGATHER_COST
# shard(source) -> target
elif len(source_shard_list) == len(
target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[
-1] not in source_shard_list:
difference = SHARD_COST
# S1 -> S0 or S0 -> S1
elif len(source_shard_list) == len(target_shard_list):
# source -> R -> target
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
# R -> S01
elif len(source_shard_list) == len(target_shard_list) - 2:
difference = SHARD_COST + STEP_PENALTY + SHARD_COST
# S01 -> R
elif len(source_shard_list) == len(target_shard_list) + 2:
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
# S1 -> S01
elif len(source_shard_list) == len(target_shard_list) - 1:
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
# S01 -> S1
elif len(source_shard_list) == len(target_shard_list) + 1:
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
else:
difference = NAN
difference_dict[spec_pair] = difference
self.difference_dict = difference_dict
def difference(self, other):
'''
This function is temporarily NOT implemented, it will be codesigned with ShapeConsistency feature.
The difference between two _DimSpec.
Argument:
other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
'''
pass
difference = self.difference_dict[(str(self), str(other))]
return difference
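A few entries of the resulting 2D difference table, worked out from the cost constants defined at the top of this file:
# ('S0',  'S1') : source -> R -> target (all-gather, then shard) = 20 + 6 + 5  = 31
# ('R',   'S01'): two shard steps                                = 5 + 6 + 5   = 16
# ('S01', 'R')  : two all-gather steps                           = 20 + 6 + 20 = 46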
class ShardingSpec:
@@ -43,8 +145,9 @@ class ShardingSpec:
Argument:
device_mesh(DeviceMesh): A logical view of a physical mesh.
entire_shape(torch.Size): The entire shape of tensor before sharded.
dim_partition_dict(Dict[int, List[int]]): The key is the dimension of tensor to be sharded,
dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of the tensor to be sharded,
and the value of the key describes which logical mesh axes will be used to shard that dimension.
sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
'''
def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
@@ -79,12 +182,18 @@ class ShardingSpec:
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
def convert_dict_to_shard_sequence(self):
'''
Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
'''
sharding_sequence = [_DimSpec([])] * len(self.entire_shape)
for dim, shard_list in self.dim_partition_dict.items():
sharding_sequence[dim] = _DimSpec(shard_list)
self.sharding_sequence = sharding_sequence
def convert_shard_sequence_to_dict(self):
'''
Convert sharding_sequence into dim_partition_dict.
'''
new_dim_partition_dict = {}
for index, dim_spec in enumerate(self.sharding_sequence):
if not dim_spec.is_replica:
@@ -95,6 +204,45 @@ class ShardingSpec:
def sharding_sequence_difference(self, other):
'''
This function is temporarily NOT implemented, it will be codesigned with ShapeConsistency feature.
This function is a naive version of the difference computation. It simply accumulates the difference of every
dimension between the pair of sharding sequences.
Example:
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
dim_partition_dict_to_compare = {0: [0], 1: [1]}
# DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
Output:
25
Argument:
other(ShardingSpec): The ShardingSpec to compare with.
Return:
difference(int): Difference between two ShardingSpec.
'''
pass
assert len(self.sharding_sequence) == len(
other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.'
difference = 0
for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
difference += orig_dim_spec.difference(other_dim_spec)
return difference
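The value 25 in the example above can be traced dimension by dimension with the cost constants defined at the top of this file:
# [S01, R, R] vs [S0, S1, R]:
#   dim 0: S01 -> S0 needs one all-gather  -> ALLGATHER_COST = 20
#   dim 1: R   -> S1 needs one shard       -> SHARD_COST     = 5
#   dim 2: R   -> R  is already consistent -> 0
# total difference = 20 + 5 + 0 = 25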
def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
for dim, shard_list in self.dim_partition_dict.items():
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert sharded_shape[
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
sharded_shape[dim] //= shard_partitions
return torch.Size(sharded_shape)
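A worked example of the division above, using the (4, 4) mesh and (64, 32, 16) entire shape from the test file:
# dim_partition_dict = {1: [0, 1]} shards tensor dim 1 across mesh axes 0 and 1,
# i.e. into 4 * 4 = 16 partitions, so the per-device shape is
# (64, 32 // 16, 16) == (64, 2, 16)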


@@ -5,6 +5,90 @@ import torch.nn as nn
from colossalai.tensor.colo_tensor import ColoTensor
def all_gather_simulator(target_pair):
'''
Simulate the all-gather operation and its influence on the DimSpec
(the communication cost is computed separately by CommSpec).
We don't allow non-contiguous layouts, so all-gather(S012) -> S02 is NOT allowed.
Therefore, the all-gather operation just removes the last element of the shard list,
e.g.:
all-gather(S01) -> S0
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
and the second element describes which logical mesh axes shard that dimension.
'''
_, shard_list = target_pair
new_shard_list = shard_list[:-1]
return new_shard_list
def all_to_all_simulator(f_target_pair, b_target_pair):
'''
Simulate the all-to-all operation and its influence on the DimSpecs
(the communication cost is computed separately by CommSpec).
We ban all representations whose shard_list is in decreasing order,
such as S10, so all-to-all(S0, S1) -> [R, S10] is NOT allowed.
Therefore, if the behind shard_list is not empty, we extend the front shard_list with it,
e.g.:
all-to-all(S0, S1) -> [S01, R]
all-to-all(R, S1) -> [S1, R]
Otherwise, we move the front shard_list to the behind position,
e.g.:
all-to-all(S0, R) -> [R, S0]
Argument:
f_target_pair(Tuple[int, List[int]]): The first element is the front tensor dimension,
and the second element describes which logical mesh axes shard that dimension.
b_target_pair(Tuple[int, List[int]]): The first element is the behind tensor dimension,
and the second element describes which logical mesh axes shard that dimension.
'''
_, f_shard_list = f_target_pair
_, b_shard_list = b_target_pair
if not len(b_shard_list):
b_shard_list.extend(f_shard_list)
f_shard_list = []
else:
f_shard_list.extend(b_shard_list)
b_shard_list = []
return f_shard_list, b_shard_list
def shard_simulator(target_pair, legal_sharding_dims):
'''
Simulate the shard operation and its influence on the DimSpec
(the communication cost of sharding is always ZERO).
We don't allow non-contiguous layouts, so shard(S0) -> S02 is NOT allowed.
In addition, we ban all representations whose shard_list is in decreasing order,
such as S10, so shard(S1) -> S10 is NOT allowed.
Therefore, for an R dimension, we can simply append any legal sharding dim to it,
e.g.:
shard(R) -> S0
For an S dimension, we need to make sure the shard_list after sharding is still in increasing order,
e.g.:
shard(S0) -> S01
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
and the second element describes which logical mesh axes shard that dimension.
'''
_, shard_list = target_pair
shard_list_list = []
for dim in legal_sharding_dims:
if len(shard_list) != 0 and dim <= shard_list[-1]:
continue
new_shard_list = shard_list + [dim]
shard_list_list.append(new_shard_list)
return shard_list_list
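A quick sketch exercising the three simulators on their docstring examples (assuming this module is importable as colossalai.tensor.utils, as the imports in the files above suggest):
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator

# all-gather(S01) -> S0: drop the last mesh axis from the shard list
assert all_gather_simulator((0, [0, 1])) == [0]
# all-to-all(S0, S1) -> [S01, R]: the behind shard list is appended to the front one
assert all_to_all_simulator((0, [0]), (1, [1])) == ([0, 1], [])
# shard(S0) -> S01 on a 2D mesh: only mesh axis 1 keeps the shard list increasing
assert shard_simulator((0, [0]), [0, 1]) == [[0, 1]]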
# The function is credited to PyTorch Team
def named_params_with_colotensor(
module: nn.Module,


@@ -1,29 +1,32 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((64, 32, 16))
shape_consistency_manager = ShapeConsistencyManager()
def test_one_step_transform():
def test_shape_consistency():
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 8, 6))
dim_partition_dict = {0: [0], 1: [1]}
# DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
# {DistSpec:
# shard_sequence: R,S1,R
# device_mesh_shape: (4, 4): 0, DistSpec:
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec:
# shard_sequence: S0,R,R
# device_mesh_shape: (4, 4): 0}
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)}
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
assert '[R, S1, R]' in [
@@ -39,12 +42,12 @@ def test_shape_consistency():
# device_mesh_shape: (4, 4)
sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_all2all)
# {DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4): 0, DistSpec:
# shard_sequence: R,S1,S0
# device_mesh_shape: (4, 4): 0, DistSpec:
# shard_sequence: S0,R,S1
# device_mesh_shape: (4, 4): 0}
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 1), 0), DistSpec:
# shard_sequence: R,S1,S0
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec:
# shard_sequence: S0,R,S1
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)}
rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0)
assert '[S01, R, R]' in [
@@ -63,12 +66,12 @@ def test_shape_consistency():
# device_mesh_shape: (4, 4)
sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, dim_partition_shard)
# {DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4): 0, DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4): 0, DistSpec:
# shard_sequence: S0,R,S1
# device_mesh_shape: (4, 4): 0}
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1), 0), DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec:
# shard_sequence: S0,R,S1
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)}
rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0)
assert '[S01, R, R]' in [
@@ -82,5 +85,48 @@ def test_shape_consistency():
]
def test_shape_consistency():
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
# DistSpec:
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
sharding_spec_source, sharding_spec_target)
transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path])
assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'
# all-gather(S01) -> S0
assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.ALLGATHER
assert comm_action_sequence[0].gather_dim == 1
assert comm_action_sequence[0].logical_process_axis == 1
# all-to-all(R, S0) -> [S0, R]
assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALLTOALL
assert comm_action_sequence[1].gather_dim == 1
assert comm_action_sequence[1].shard_dim == 0
assert comm_action_sequence[1].logical_process_axis == 0
# shard(S0) -> [S01]
assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SHARD
assert comm_action_sequence[2].shard_dim == 0
assert comm_action_sequence[2].logical_process_axis == 1
assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]',
'[S01, R, R]')][0] == transform_path
assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]',
'[S01, R, R]')][1] == comm_action_sequence
if __name__ == '__main__':
test_one_step_transform()
test_shape_consistency()