mirror of https://github.com/hpcaitech/ColossalAI
[tensor] add shape consistency feature to support auto spec transform (#1418)
* [tensor] add shape consistency feature to supportauto sharding spec transform. * [tensor] remove unused argument in simulator, add doc string for target pair.pull/1424/head^2
parent
4fb3c52cf0
commit
33f0744d51
|
@ -0,0 +1,320 @@
|
|||
import torch
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from enum import Enum
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class CollectiveCommPattern(Enum):
|
||||
ALLGATHER = 'all_gather'
|
||||
ALLTOALL = 'all_to_all'
|
||||
SHARD = 'shard'
|
||||
|
||||
|
||||
class ShapeConsistencyManager:
|
||||
|
||||
def __init__(self, consistency_option=None):
|
||||
self.consistency_option = consistency_option
|
||||
self.total_communication_cost = 0
|
||||
self.total_transform_steps = 0
|
||||
self.cached_spec_pairs = {}
|
||||
|
||||
def _all_gather_simulator(self, target_pair):
|
||||
'''
|
||||
Simulating all-gather operation, analyze the communication cost
|
||||
and simulate the influence of the DimSpec.
|
||||
|
||||
We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
|
||||
Therefore, all gather operation just remove the last element in shard list,
|
||||
e.g.:
|
||||
all-gather(S01) -> S0
|
||||
|
||||
Argument:
|
||||
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
||||
and the second element decribes which logical axis will be sharded in that dimension.
|
||||
'''
|
||||
_, shard_list = target_pair
|
||||
new_shard_list = shard_list[:-1]
|
||||
# TODO: compute comm cost
|
||||
comm_cost = 0
|
||||
return new_shard_list, comm_cost
|
||||
|
||||
def _all_to_all_simulator(self, f_target_pair, b_target_pair):
|
||||
'''
|
||||
Simulating all-to-all operation, analyze the communication cost
|
||||
and simulate the influence of the DimSpec.
|
||||
|
||||
We BANNED all representations which shard_list in decreasing order,
|
||||
such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
|
||||
Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
|
||||
Argument:
|
||||
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
||||
and the second element decribes which logical axis will be sharded in that dimension.
|
||||
e.g.:
|
||||
all-to-all(S0, S1) -> [S01, R]
|
||||
all-to-all(S0, R) -> [R, S0]
|
||||
Otherwise, we extend the front shard_list to behind.
|
||||
e.g.:
|
||||
all-to-all(R, S1) -> [S1, R]
|
||||
|
||||
Argument:
|
||||
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
||||
and the second element decribes which logical axis will be sharded in that dimension.
|
||||
'''
|
||||
_, f_shard_list = f_target_pair
|
||||
_, b_shard_list = b_target_pair
|
||||
if not len(b_shard_list):
|
||||
b_shard_list.extend(f_shard_list)
|
||||
f_shard_list = []
|
||||
else:
|
||||
f_shard_list.extend(b_shard_list)
|
||||
b_shard_list = []
|
||||
# TODO: compute comm cost
|
||||
comm_cost = 0
|
||||
return f_shard_list, b_shard_list, comm_cost
|
||||
|
||||
def _shard_simulator(self, target_pair, legal_sharding_dims):
|
||||
'''
|
||||
Simulating shard operation, analyze the communication cost(always ZERO)
|
||||
and simulate the influence of the DimSpec.
|
||||
|
||||
We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
|
||||
In addition, We BANNED all representations which shard_list in decreasing order,
|
||||
such as S10, so shard(S0) -> S10 is NOT allowed.
|
||||
Therefore, for the R dimension, we could just append any legal sharding dim on it.
|
||||
e.g.:
|
||||
shard(R) -> S0
|
||||
For the S dimension, we need to make sure the shard_list after sharding still keep rising order.
|
||||
e.g:
|
||||
shard(S0) -> S01
|
||||
|
||||
Argument:
|
||||
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
||||
and the second element decribes which logical axis will be sharded in that dimension.
|
||||
'''
|
||||
_, shard_list = target_pair
|
||||
shard_list_list = []
|
||||
for dim in legal_sharding_dims:
|
||||
if len(shard_list) != 0 and dim <= shard_list[-1]:
|
||||
continue
|
||||
new_shard_list = shard_list + [dim]
|
||||
shard_list_list.append(new_shard_list)
|
||||
comm_cost = 0
|
||||
return shard_list_list, comm_cost
|
||||
|
||||
def get_all_all_gather_spec(self, source_spec, orig_cost):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single all-gather operation, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
For the all-gather operation, we just care about the S dimension.
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
|
||||
|
||||
Example:
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
|
||||
print(rst_dict)
|
||||
|
||||
Output:
|
||||
{DistSpec:
|
||||
shard_sequence: R,S1,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,R
|
||||
device_mesh_shape: (4, 4): 0}
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
for target_pair in source_spec.dim_partition_dict.items():
|
||||
shard_list, cost = self._all_gather_simulator(target_pair)
|
||||
index = target_pair[0]
|
||||
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
|
||||
new_dim_partition_dict[index] = shard_list
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
valid_spec_dict[new_sharding_spec] = orig_cost + cost
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_all_to_all_spec(self, source_spec, orig_cost):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single all-to-all operation, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
For the all-to-all operation, we just care about the pairs containing S dimension.
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
|
||||
|
||||
Example:
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
|
||||
print(rst_dict)
|
||||
|
||||
Output:
|
||||
{DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: R,S1,S0
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,S1
|
||||
device_mesh_shape: (4, 4): 0}
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
tensor_dims = len(source_spec.entire_shape)
|
||||
for f_index in range(tensor_dims - 1):
|
||||
for b_index in range(f_index + 1, tensor_dims):
|
||||
# skip (R, R) cases
|
||||
if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
|
||||
continue
|
||||
else:
|
||||
if f_index in source_spec.dim_partition_dict:
|
||||
f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
|
||||
else:
|
||||
f_target_pair = (f_index, [])
|
||||
if b_index in source_spec.dim_partition_dict:
|
||||
b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
|
||||
else:
|
||||
b_target_pair = (b_index, [])
|
||||
|
||||
f_shard_list, b_shard_list, cost = self._all_to_all_simulator(f_target_pair, b_target_pair)
|
||||
f_index = f_target_pair[0]
|
||||
b_index = b_target_pair[0]
|
||||
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
|
||||
new_dim_partition_dict[f_index] = f_shard_list
|
||||
new_dim_partition_dict[b_index] = b_shard_list
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
valid_spec_dict[new_sharding_spec] = orig_cost + cost
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_shard_spec(self, source_spec, orig_cost):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single shard operation, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
For the sharding operation, we just care about legal sharding dimensions.
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
|
||||
|
||||
Example:
|
||||
dim_partition_dict = {0: [0]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
|
||||
print(rst_dict)
|
||||
|
||||
Output:
|
||||
{DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,S1,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,S1
|
||||
device_mesh_shape: (4, 4): 0}
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
|
||||
for dim, shard_list in source_spec.dim_partition_dict.items():
|
||||
for element in shard_list:
|
||||
legal_sharding_dims.remove(element)
|
||||
if len(legal_sharding_dims) == 0:
|
||||
return valid_spec_dict
|
||||
|
||||
tensor_dims = len(source_spec.entire_shape)
|
||||
for index in range(tensor_dims):
|
||||
if index not in source_spec.dim_partition_dict:
|
||||
shard_list_list, cost = self._shard_simulator((index, []), legal_sharding_dims)
|
||||
else:
|
||||
shard_list_list, cost = self._shard_simulator((index, source_spec.dim_partition_dict[index]),
|
||||
legal_sharding_dims)
|
||||
if not shard_list_list:
|
||||
continue
|
||||
for shard_list in shard_list_list:
|
||||
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
|
||||
new_dim_partition_dict[index] = shard_list
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
valid_spec_dict[new_sharding_spec] = orig_cost + cost
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_one_step_transform_spec(self, source_spec, orig_cost):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with one step transform, and
|
||||
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
|
||||
Note:
|
||||
all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
|
||||
and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive,
|
||||
we could safely put them together.
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost))
|
||||
valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost))
|
||||
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost))
|
||||
return valid_spec_dict
|
||||
|
||||
def shape_consistency(self, source_spec, target_spec):
|
||||
'''
|
||||
This method will find a path to transform source_spec to target_spec with
|
||||
a greedy algorithm.
|
||||
The basic idea is:
|
||||
Step1:
|
||||
Generate all one-step transform sequences from source_spec.
|
||||
Step2:
|
||||
Pick the 'best' sharding spec following the heuristic function.
|
||||
Step3:
|
||||
Repeat above steps until the source spec transform to target spec.
|
||||
|
||||
This function is NOT completed, due to absense of difference function.
|
||||
'''
|
||||
MAX_TRANSFORM_STEPS = 10
|
||||
total_cost = 0
|
||||
total_steps = 0
|
||||
transform_path = []
|
||||
temp_sharding_spec = deepcopy(source_spec)
|
||||
transform_path.append(temp_sharding_spec)
|
||||
while total_steps <= MAX_TRANSFORM_STEPS:
|
||||
valid_transform_spec_dict = get_all_one_step_transform_spec(temp_sharding_spec)
|
||||
best_difference_score = 0
|
||||
for sharding_spec, cost in valid_transform_spec_dict.items():
|
||||
if no_difference(sharding_spec, target_spec):
|
||||
total_cost += cost
|
||||
transform_path.append(sharding_spec)
|
||||
return (transform_path, total_cost)
|
||||
if difference(sharding_spec, target_spec) > best_difference_score:
|
||||
temp_sharding_spec = deepcopy(sharding_spec)
|
||||
temp_cost = cost
|
||||
transform_path.append(temp_sharding_spec)
|
||||
total_cost += temp_cost
|
||||
return (transform_path, total_cost)
|
|
@ -13,7 +13,7 @@ class _DimSpec:
|
|||
'''
|
||||
|
||||
def __init__(self, shard_list):
|
||||
self.is_replica = shard_list is None
|
||||
self.is_replica = len(shard_list) == 0
|
||||
self.shard_list = shard_list
|
||||
|
||||
def __eq__(self, other):
|
||||
|
@ -52,12 +52,16 @@ class ShardingSpec:
|
|||
and the value of the key decribe which logical axis will be sharded in that dimension.
|
||||
'''
|
||||
|
||||
def __init__(self, device_mesh, entire_shape, dim_partition_dict):
|
||||
def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
|
||||
self.device_mesh = device_mesh
|
||||
self.entire_shape = entire_shape
|
||||
self.dim_partition_dict = dim_partition_dict
|
||||
self.sharding_sequence = sharding_sequence
|
||||
if self.sharding_sequence is None:
|
||||
self.convert_dict_to_shard_sequence()
|
||||
elif self.dim_partition_dict is None:
|
||||
self.convert_shard_sequence_to_dict()
|
||||
self._sanity_check()
|
||||
self.sharding_sequence = self.convert_dict_to_shard_sequence()
|
||||
|
||||
def __repr__(self):
|
||||
res_list = ["DistSpec:"]
|
||||
|
@ -80,10 +84,19 @@ class ShardingSpec:
|
|||
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
|
||||
|
||||
def convert_dict_to_shard_sequence(self):
|
||||
sharding_sequence = [_DimSpec(None)] * len(self.entire_shape)
|
||||
sharding_sequence = [_DimSpec([])] * len(self.entire_shape)
|
||||
for dim, shard_list in self.dim_partition_dict.items():
|
||||
sharding_sequence[dim] = _DimSpec(shard_list)
|
||||
return sharding_sequence
|
||||
self.sharding_sequence = sharding_sequence
|
||||
|
||||
def convert_shard_sequence_to_dict(self):
|
||||
new_dim_partition_dict = {}
|
||||
for index, dim_spec in enumerate(self.sharding_sequence):
|
||||
if not dim_spec.is_replica:
|
||||
if index not in new_dim_partition_dict:
|
||||
new_dim_partition_dict[index] = []
|
||||
new_dim_partition_dict[index].append(dim_spec.shard_list)
|
||||
self.dim_partition_dict = new_dim_partition_dict
|
||||
|
||||
def sharding_sequence_difference(self, other):
|
||||
'''
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
import torch
|
||||
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
def test_shape_consistency():
|
||||
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
|
||||
mesh_shape = (4, 4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7],
|
||||
# [8, 9, 10,11],
|
||||
# [12,13,14,15]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 8, 6))
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
# {DistSpec:
|
||||
# shard_sequence: R,S1,R
|
||||
# device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
# shard_sequence: S0,R,R
|
||||
# device_mesh_shape: (4, 4): 0}
|
||||
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
|
||||
|
||||
assert '[R, S1, R]' in [
|
||||
str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
|
||||
]
|
||||
assert '[S0, R, R]' in [
|
||||
str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
|
||||
]
|
||||
|
||||
dim_partition_dict_all2all = {0: [0], 1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_all2all)
|
||||
# {DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
# device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
# shard_sequence: R,S1,S0
|
||||
# device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
# shard_sequence: S0,R,S1
|
||||
# device_mesh_shape: (4, 4): 0}
|
||||
rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0)
|
||||
|
||||
assert '[S01, R, R]' in [
|
||||
str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
|
||||
]
|
||||
assert '[R, S1, S0]' in [
|
||||
str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
|
||||
]
|
||||
assert '[S0, R, S1]' in [
|
||||
str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
|
||||
]
|
||||
|
||||
dim_partition_shard = {0: [0]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, dim_partition_shard)
|
||||
# {DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
# device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
# shard_sequence: S0,R,S1
|
||||
# device_mesh_shape: (4, 4): 0}
|
||||
rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0)
|
||||
|
||||
assert '[S01, R, R]' in [
|
||||
str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
|
||||
]
|
||||
assert '[S0, S1, R]' in [
|
||||
str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
|
||||
]
|
||||
assert '[S0, R, S1]' in [
|
||||
str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
|
||||
]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_shape_consistency()
|
Loading…
Reference in New Issue