ColossalAI/colossalai/tensor/shape_consistency.py

760 lines
35 KiB
Python

import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
from .comm_spec import *
__all__ = ["ShapeConsistencyManager", "ShapeConsistencyOptions", "set_shape_consistency_options"]
@dataclass
class ShapeConsistencyOptions:
    """
    ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency.

    The dataclass declares no fields yet. It is the type that the ``options`` setter of
    ``ShapeConsistencyManager`` asserts on, and the type accepted by
    ``set_shape_consistency_options``.
    """

    # TODO: shape consistency option is not implemented yet
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Convert a tensor laid out according to ``sharding_spec`` into the layout without any sharded dimension.

    The target spec is built on the same device mesh and entire shape as ``sharding_spec`` but with an
    empty ``dim_partition_dict``. The conversion itself is delegated to
    ``ShapeConsistencyManager.apply_for_autoparallel_runtime`` and runs under ``torch.no_grad()``.

    Argument:
        distributed_tensor(torch.Tensor): the tensor whose current layout is described by ``sharding_spec``.
        sharding_spec(ShardingSpec): the sharding spec of ``distributed_tensor``.

    Return:
        torch.Tensor: the tensor returned by the manager for the un-sharded target spec.
    """
    manager = ShapeConsistencyManager()
    # an empty dim_partition_dict means that no tensor dimension is sharded
    replicated_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
    with torch.no_grad():
        return manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec, replicated_spec)
def set_shape_consistency_options(options: ShapeConsistencyOptions):
    """
    Configure the shape consistency manager via function call.

    Argument:
        options(ShapeConsistencyOptions): the preferences to store on the manager; the manager's
            ``options`` setter asserts that this is a ``ShapeConsistencyOptions`` instance.
    """
    # assigning through the property runs the setter's type assertion
    ShapeConsistencyManager().options = options
class ShapeConsistencyManager(metaclass=SingletonMeta):
def __init__(self):
    """Create a manager with no options, ``forward_only`` disabled and empty bookkeeping."""
    # preferences installed through the ``options`` property; None until configured
    self._options = None
    # forwarded to every CommSpec created by the ``get_all_*_spec`` methods
    self._forward_only = False
    # counters, both starting at zero
    self.total_communication_cost = self.total_transform_steps = 0
    # (str(source sharding sequence), str(target sharding sequence)) -> (transform_path, comm_action_sequence)
    self.cached_spec_pairs_transform_path = {}
@property
def options(self):
    """Return the stored ``ShapeConsistencyOptions``, or ``None`` when none has been set."""
    return self._options

@options.setter
def options(self, options_: ShapeConsistencyOptions):
    """Store ``options_`` after asserting that it is a ``ShapeConsistencyOptions`` instance."""
    assert isinstance(options_, ShapeConsistencyOptions)
    self._options = options_
@property
def forward_only(self):
    """Return the flag that is passed as ``forward_only`` to every CommSpec built by this manager."""
    return self._forward_only

@forward_only.setter
def forward_only(self, value):
    """Store ``value`` after asserting that it is a ``bool``."""
    assert isinstance(value, bool)
    self._forward_only = value
def get_all_all_gather_spec(
    self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
) -> Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]:
    """
    Get all valid sharding specs from source_spec with single all-gather operation, and
    accumulate communication cost on origin cost which will finally be used in auto sharding solver.
    For the all-gather operation, we just care about the S dimension.

    Argument:
        source_spec(ShardingSpec): the ShardingSpec of the source_spec.
        orig_cost_dict(Dict[str, float]): the communication cost accumulated before this operation, keyed by
            phase. It must contain every phase returned by ``CommSpec.get_comm_cost``.

    Return:
        valid_spec_dict(Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]): every valid sharding spec
            reachable from source_spec with a single all-gather operation, mapped to a tuple of the CommSpec
            recording that operation and the cost dict (cost of the operation plus ``orig_cost_dict``).

    Example:
        dim_partition_dict = {0: [0], 1: [1]}
        # DistSpec:
        #     shard_sequence: S0,S1,R
        #     device_mesh_shape: (4, 4)
        sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
        shape_consistency_manager = ShapeConsistencyManager()
        rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
        print(list(rst_dict.keys()))

    Output:
        [DistSpec:
            shard_sequence: R,S1,R
            device_mesh_shape: (4, 4), DistSpec:
            shard_sequence: S0,R,R
            device_mesh_shape: (4, 4)]
    """
    valid_spec_dict = {}
    comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
    for target_pair in source_spec.dim_partition_dict.items():
        # target_pair is (tensor dim, list of mesh axes sharding that dim)
        shard_list = all_gather_simulator(target_pair)
        index = target_pair[0]
        new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)

        # We won't add empty list into dim_partition_dict
        # The key will be popped if the related shard_list is empty
        if shard_list:
            new_dim_partition_dict[index] = shard_list
        else:
            new_dim_partition_dict.pop(index)

        # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
        comm_spec = CommSpec(
            comm_pattern,
            sharding_spec=source_spec,
            gather_dim=index,
            # shard_dim will be used during backward
            shard_dim=index,
            # the gather runs along the last mesh axis of the shard list
            logical_process_axis=target_pair[1][-1],
            forward_only=self.forward_only,
        )

        # compute the communication cost with CommSpec
        cost_dict = comm_spec.get_comm_cost()

        # generate new sharding spec; only the constructor is expected to reject a layout
        try:
            new_sharding_spec = ShardingSpec(
                source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
            )
        except ShardingSpecException:
            continue

        # build a fresh dict so that the dict returned by get_comm_cost is not modified
        accumulated_cost_dict = {phase: cost + orig_cost_dict[phase] for phase, cost in cost_dict.items()}
        valid_spec_dict[new_sharding_spec] = (comm_spec, accumulated_cost_dict)
    return valid_spec_dict
def get_all_all_to_all_spec(
    self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
) -> Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]:
    """
    Get all valid sharding specs from source_spec with single all-to-all operation, and
    accumulate communication cost on origin cost which will finally be used in auto sharding solver.
    For the all-to-all operation, we just care about the pairs containing S dimension.

    Argument:
        source_spec(ShardingSpec): the ShardingSpec of the source_spec.
        orig_cost_dict(Dict[str, float]): the communication cost accumulated before this operation, keyed by
            phase.

    Return:
        valid_spec_dict(Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]): every valid sharding spec
            reachable from source_spec with a single all-to-all operation, mapped to a tuple of the CommSpec
            recording that operation and the cost dict (cost of the operation plus ``orig_cost_dict``).

    Example:
        dim_partition_dict = {0: [0], 1: [1]}
        # DistSpec:
        #     shard_sequence: S0,S1,R
        #     device_mesh_shape: (4, 4)
        sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
        shape_consistency_manager = ShapeConsistencyManager()
        rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
        print(list(rst_dict.keys()))

    Output:
        [DistSpec:
            shard_sequence: S01,R,R
            device_mesh_shape: (4, 4), DistSpec:
            shard_sequence: R,S1,S0
            device_mesh_shape: (4, 4), DistSpec:
            shard_sequence: S0,R,S1
            device_mesh_shape: (4, 4)]
    """
    valid_spec_dict = {}
    comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
    tensor_dims = len(source_spec.entire_shape)
    # enumerate every ordered pair of tensor dims (f_index < b_index)
    for f_index in range(tensor_dims - 1):
        for b_index in range(f_index + 1, tensor_dims):
            # skip (R, R) cases
            if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
                continue
            else:
                if f_index in source_spec.dim_partition_dict:
                    # skip: (S01, R) -> (R, S01) is NOT allowed
                    if len(source_spec.dim_partition_dict[f_index]) >= 2:
                        continue
                    # NOTE(review): the deepcopy presumably protects source_spec from in-place edits made by
                    # all_to_all_simulator - confirm against the simulator implementation
                    f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
                else:
                    f_target_pair = (f_index, [])
                if b_index in source_spec.dim_partition_dict:
                    # skip: (R, S01) -> (S01, R) is NOT allowed
                    if len(source_spec.dim_partition_dict[b_index]) >= 2:
                        continue
                    b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
                else:
                    b_target_pair = (b_index, [])

            # skip (S1, S0) -> S10
            if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]:
                continue
            f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair)
            # both pairs were built from the loop indices above, so these assignments keep the same values
            f_index = f_target_pair[0]
            b_index = b_target_pair[0]

            # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
            # NOTE(review): this comparison reads f_target_pair[1] / b_target_pair[1] after the simulator
            # call, so which branch runs depends on whether the simulator modifies the lists it was given -
            # confirm before restructuring this code
            if len(f_shard_list) < len(f_target_pair[1]):
                # the front dim lost mesh axes: gather it and shard the back dim
                gather_dim = f_index
                shard_dim = b_index
                logical_process_axis = f_target_pair[1][-1]
            else:
                # otherwise gather the back dim and shard the front dim
                gather_dim = b_index
                shard_dim = f_index
                logical_process_axis = b_target_pair[1][-1]
            comm_spec = CommSpec(
                comm_pattern,
                sharding_spec=source_spec,
                gather_dim=gather_dim,
                shard_dim=shard_dim,
                logical_process_axis=logical_process_axis,
                forward_only=self.forward_only,
            )

            # compute the communication cost with CommSpec
            cost_dict = comm_spec.get_comm_cost()
            new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)

            # We won't add empty list into dim_partition_dict
            # The key will be popped if the related shard_list is empty
            if f_shard_list:
                new_dim_partition_dict[f_index] = f_shard_list
            else:
                new_dim_partition_dict.pop(f_index)
            if b_shard_list:
                new_dim_partition_dict[b_index] = b_shard_list
            else:
                new_dim_partition_dict.pop(b_index)

            # generate new sharding spec
            try:
                new_sharding_spec = ShardingSpec(
                    source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
                )
                # add the cost accumulated so far to the cost of this single operation
                for phase, cost in cost_dict.items():
                    cost_dict[phase] = cost + orig_cost_dict[phase]
                valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
            except ShardingSpecException:
                # layouts rejected by ShardingSpec are simply not offered as candidates
                pass
    return valid_spec_dict
def get_all_shard_spec(
    self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
) -> Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]:
    """
    Get all valid sharding specs from source_spec with single shard operation, and
    accumulate communication cost on origin cost which will finally be used in auto sharding solver.
    For the sharding operation, we just care about legal sharding dimensions.

    Argument:
        source_spec(ShardingSpec): the ShardingSpec of the source_spec.
        orig_cost_dict(Dict[str, float]): the communication cost accumulated before this operation, keyed by
            phase. It must contain every phase returned by ``CommSpec.get_comm_cost``.

    Return:
        valid_spec_dict(Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]): every valid sharding spec
            reachable from source_spec with a single shard operation, mapped to a tuple of the CommSpec
            recording that operation and the cost dict (cost of the operation plus ``orig_cost_dict``).
            The dict is empty when every mesh axis is already used by source_spec.

    Example:
        dim_partition_dict = {0: [0]}
        # DistSpec:
        #     shard_sequence: S0,R,R
        #     device_mesh_shape: (4, 4)
        sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
        shape_consistency_manager = ShapeConsistencyManager()
        rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
        print(list(rst_dict.keys()))

    Output:
        [DistSpec:
            shard_sequence: S01,R,R
            device_mesh_shape: (4, 4), DistSpec:
            shard_sequence: S0,S1,R
            device_mesh_shape: (4, 4), DistSpec:
            shard_sequence: S0,R,S1
            device_mesh_shape: (4, 4)]
    """
    valid_spec_dict = {}
    comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD

    # legal sharding dims means the mesh_id is still available to use.
    legal_sharding_dims = list(range(len(source_spec.device_mesh.shape)))
    for used_mesh_axes in source_spec.dim_partition_dict.values():
        for mesh_axis in used_mesh_axes:
            legal_sharding_dims.remove(mesh_axis)

    if not legal_sharding_dims:
        return valid_spec_dict

    tensor_dims = len(source_spec.entire_shape)
    for index in range(tensor_dims):
        # a tensor dim that is not sharded yet starts from an empty shard list
        current_shard_list = source_spec.dim_partition_dict.get(index, [])
        shard_list_list = shard_simulator((index, current_shard_list), legal_sharding_dims)
        if not shard_list_list:
            continue
        for shard_list in shard_list_list:
            new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
            new_dim_partition_dict[index] = shard_list

            # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
            comm_spec = CommSpec(
                comm_pattern,
                sharding_spec=source_spec,
                gather_dim=index,
                shard_dim=index,
                # the split runs along the last mesh axis of the new shard list
                logical_process_axis=shard_list[-1],
                forward_only=self.forward_only,
            )

            # compute the communication cost with CommSpec
            cost_dict = comm_spec.get_comm_cost()

            # generate new sharding spec; only the constructor is expected to reject a layout
            try:
                new_sharding_spec = ShardingSpec(
                    source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
                )
            except ShardingSpecException:
                continue

            # build a fresh dict so that the dict returned by get_comm_cost is not modified
            accumulated_cost_dict = {phase: cost + orig_cost_dict[phase] for phase, cost in cost_dict.items()}
            valid_spec_dict[new_sharding_spec] = (comm_spec, accumulated_cost_dict)
    return valid_spec_dict
def get_all_mix_gather_spec(
    self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
) -> Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]:
    """
    Get the sharding specs reachable from source_spec with a single mix-gather operation, which removes
    the sharding of a pair of tensor dimensions at once:
        S0S1 -> RR
        S1S0 -> RR
        S01R -> RR
        RS01 -> RR

    Argument:
        source_spec(ShardingSpec): the ShardingSpec of the source_spec.
        orig_cost_dict(Dict[str, float]): the communication cost accumulated before this operation, keyed by
            phase. It must contain every phase returned by ``CommSpec.get_comm_cost``.

    Return:
        valid_spec_dict(Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]): the target spec (built with
            an empty dim_partition_dict) mapped to a tuple of the CommSpec recording the mix-gather and the
            cost dict (cost of the operation plus ``orig_cost_dict``).
    """
    valid_spec_dict = {}
    comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
    source_partition_dict = source_spec.dim_partition_dict
    tensor_dims = len(source_spec.entire_shape)
    for f_index in range(tensor_dims - 1):
        for b_index in range(f_index + 1, tensor_dims):
            # skip (R, R) cases
            if f_index not in source_partition_dict and b_index not in source_partition_dict:
                continue

            # Build both pairs from the dims of THIS iteration before inspecting them. The copies keep
            # source_spec untouched in case the simulator edits the lists it receives.
            f_target_pair = (f_index, deepcopy(source_partition_dict.get(f_index, [])))
            b_target_pair = (b_index, deepcopy(source_partition_dict.get(b_index, [])))

            # skip (S10, R) -> (R, R) and (R, S10) -> (R, R)
            if any(
                len(mesh_axes) == 2 and mesh_axes[0] >= mesh_axes[1]
                for mesh_axes in (f_target_pair[1], b_target_pair[1])
            ):
                continue

            gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
            comm_spec = CommSpec(
                comm_pattern,
                sharding_spec=source_spec,
                gather_dim=gather_dim,
                logical_process_axis=logical_process_axes,
                forward_only=self.forward_only,
                mix_gather=True,
            )

            # compute the communication cost with CommSpec
            cost_dict = comm_spec.get_comm_cost()

            # the target of a mix-gather has no sharded dimension left
            new_dim_partition_dict = {}

            # generate new sharding spec; only the constructor is expected to reject a layout
            try:
                new_sharding_spec = ShardingSpec(
                    source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict
                )
            except ShardingSpecException:
                continue

            # build a fresh dict so that the dict returned by get_comm_cost is not modified
            accumulated_cost_dict = {phase: cost + orig_cost_dict[phase] for phase, cost in cost_dict.items()}
            valid_spec_dict[new_sharding_spec] = (comm_spec, accumulated_cost_dict)
    return valid_spec_dict
def get_all_one_step_transform_spec(
    self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]
) -> Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]:
    """
    Get all valid sharding specs from source_spec with one step transform, and
    accumulate communication cost on origin cost which will finally be used in auto sharding solver.

    Note:
        all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
        and shard will add a sharding dimension. Therefore, the results of the above operations are mutually
        exclusive, and we could safely put them together.

    Argument:
        source_spec(ShardingSpec): the ShardingSpec of the source_spec.
        orig_cost_dict(Dict[str, float]): the communication cost accumulated before this operation, keyed by
            phase.

    Return:
        valid_spec_dict(Dict[ShardingSpec, Tuple[CommSpec, Dict[str, float]]]): the union of the results of
            ``get_all_all_gather_spec``, ``get_all_all_to_all_spec`` and ``get_all_shard_spec``; on a key
            collision the result of the later method wins.
    """
    valid_spec_dict = {}
    one_step_generators = (
        self.get_all_all_gather_spec,
        self.get_all_all_to_all_spec,
        self.get_all_shard_spec,
    )
    for generate_specs in one_step_generators:
        valid_spec_dict.update(generate_specs(source_spec, orig_cost_dict))
    return valid_spec_dict
def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
    """memory cost of the communication action sequence

    The forward pass replays the forward analysis of every action in order; the backward pass replays
    the backward analysis of every action in reverse order. Each pass tracks the allocated numel and the
    peak numel. The first action of a pass keeps its input tensor, every later action discards it.

    Args:
        comm_action_sequence (List[CommSpec]): list of communication actions

    Returns:
        TrainCycleItem: memory (numel) cost of such comm_action_sequence

    Raises:
        ValueError: if an action uses a communication pattern that has no memory analysis
            (currently ``MIXGATHER_FWD_SPLIT_BWD``).
    """

    def compute_shape(sharding_spec: ShardingSpec):
        """Return the per-device shape of a tensor laid out according to ``sharding_spec``.

        Every dimension of ``entire_shape`` is kept; a sharded dimension is divided by the product of
        the device mesh sizes along the mesh axes listed for it in ``dim_partition_dict``.
        """
        mesh_shape = sharding_spec.device_mesh.shape
        local_shape = []
        for dim, dim_size in enumerate(sharding_spec.entire_shape):
            num_shards = 1
            for mesh_axis in sharding_spec.dim_partition_dict.get(dim, []):
                num_shards *= mesh_shape[mesh_axis]
            local_shape.append(dim_size // num_shards)
        return local_shape

    def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
        """analyze all_gather memory footprint
        all_gather will allocate memory for the output tensor, and there will be temp memory for
        all_gather operation, which is twice the size of output tensor

        Args:
            comm_spec (CommSpec): input CommSpec
            discard_input (bool): whether to discard the input tensor
            alloc_numel (int): current allocated numel
            peak_numel (int): current peak numel
        """
        input_shape = compute_shape(comm_spec.sharding_spec)
        input_numel = np.prod(input_shape)
        output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
        peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
        alloc_numel += output_numel
        if discard_input:
            alloc_numel -= input_numel
        return alloc_numel, peak_numel

    def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
        """analyze split memory footprint
        split will allocate memory for the output tensor if we don't apply shard on the first dimension of
        the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not
        generate new tensor in this case, so no memory will be allocated.

        Args:
            comm_spec (CommSpec): input CommSpec
            discard_input (bool): whether to discard the input tensor
            alloc_numel (int): current allocated numel
            peak_numel (int): current peak numel
        """
        shard_dim = comm_spec.shard_dim
        if shard_dim != 0:
            # if we don't shard the tensor on the first dimension, the split action will
            # generate a new tensor
            input_shape = compute_shape(comm_spec.sharding_spec)
            input_numel = np.prod(input_shape)
            output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
            alloc_numel += output_numel
            peak_numel = max(peak_numel, alloc_numel)
            if discard_input:
                alloc_numel -= input_numel
        else:
            # if we shard the tensor on the first dimension, the split action will not generate
            # a new tensor, and as it will preserve a reference to the input tensor, we could
            # override the discard_input option here
            # NOTE: this special case might fail in some weird cases, e.g. if we have three split
            # actions in the comm actions sequence, the first split action operate on the second dimension,
            # the second split action operate on the first dimension, and the third split action operate, again,
            # on the second dimension. Therefore, after the first two actions in the sequence, we will allocate
            # memory the same size as the output of first split action. However, the third split action will discard
            # the input tensor, and it actually should discard the tensor generated by the first split action, so in
            # the current memory estimation framework, we will overestimate the memory usage. But the above case is
            # kind of weird, and I think we could ignore it for now.
            pass
        return alloc_numel, peak_numel

    def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
        """
        a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
        """
        return alloc_numel, peak_numel

    def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
        """analyze all_to_all memory footprint
        all_to_all will allocate memory for the output tensor, and temp memory of all_to_all action
        is twice the size of output tensor if we shard input tensor on the first dimension, otherwise
        the temp memory is three times the size of output tensor

        Args:
            comm_spec (CommSpec): input CommSpec
            discard_input (bool): whether to discard the input tensor
            alloc_numel (int): current allocated numel
            peak_numel (int): current peak numel
        """
        input_shape = compute_shape(comm_spec.sharding_spec)
        input_numel = np.prod(input_shape)
        output_numel = input_numel
        shard_dim = comm_spec.shard_dim
        if shard_dim != 0:
            peak_numel = max(peak_numel, alloc_numel + output_numel * 3)
        else:
            peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
        alloc_numel += output_numel
        if discard_input:
            alloc_numel -= input_numel
        return alloc_numel, peak_numel

    def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
        """
        a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
        """
        return alloc_numel, peak_numel

    def run_pass(actions, comm_specs):
        """Replay ``actions`` over ``comm_specs`` and return the final (allocated numel, peak numel)."""
        alloc_numel = 0
        peak_numel = 0
        for idx, (action, comm_spec) in enumerate(zip(actions, comm_specs)):
            # the first comm action of a pass will not discard its input
            alloc_numel, peak_numel = action(comm_spec, idx != 0, alloc_numel, peak_numel)
        return alloc_numel, peak_numel

    # comm pattern -> (forward analysis, backward analysis); None marks a pattern without analysis
    pattern_to_func_dict = {
        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: (gather_analysis, split_analysis),
        CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: (all2all_analysis, all2all_analysis),
        CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: (split_analysis, gather_analysis),
        CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: (reduce_analysis, identity_analysis),
        CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: (identity_analysis, reduce_analysis),
        CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: None,
    }

    fwd_actions = []
    bwd_actions = []

    # construct forward and backward comm actions sequence
    for comm_spec in comm_action_sequence:
        analysis_pair = pattern_to_func_dict[comm_spec.comm_pattern]
        if analysis_pair is None:
            raise ValueError(f"memory analysis is not implemented for comm pattern {comm_spec.comm_pattern}")
        fwd_action, bwd_action = analysis_pair
        fwd_actions.append(fwd_action)
        bwd_actions.append(bwd_action)

    # analyze memory footprint of forward comm actions sequence
    fwd_alloc_numel, fwd_peak_numel = run_pass(fwd_actions, comm_action_sequence)

    # analyze memory footprint for backward comm actions sequence, which runs in reverse order
    bwd_alloc_numel, bwd_peak_numel = run_pass(reversed(bwd_actions), reversed(comm_action_sequence))

    # temp memory is whatever the peak needed on top of what stays allocated
    fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
    bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
    total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel)

    return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
def shape_consistency(
    self, source_spec: ShardingSpec, target_spec: ShardingSpec
) -> Tuple[List[ShardingSpec], List[CommSpec], Dict[str, float]]:
    """
    This method will find a path to transform source_spec to target_spec with
    a greedy algorithm.
    The basic idea is:
        Step1:
            Generate all one-step transform sequences from source_spec.
        Step2:
            Pick the 'best' sharding spec following the heuristic function.
        Step3:
            Repeat above steps until the source spec transform to target spec.

    During finding the transform path, communication cost will be accumulated, and it
    will be finally used in auto parallel solver.

    Every solved path is stored in ``self.cached_spec_pairs_transform_path`` under the key
    ``(str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))``. The entry is set to
    ``(None, None)`` at the start of the search and stays that way when no path is found.

    Argument:
        source_spec(ShardingSpec): ShardingSpec of the source activation.
        target_spec(ShardingSpec): ShardingSpec of the target activation.

    Return:
        transform_path(List[ShardingSpec]): The transform path from source_spec to target_spec,
                                            it contains the source_spec and target_spec. It is empty when
                                            the two specs have no difference.
        comm_action_sequence(List[CommSpec]): Keep the communication operations to complete the shape consistency in order.
        total_cost_dict(Dict[str, float]): total cost to complete shape consistency transform, with the keys
                                           'forward', 'backward' and 'total'.

    Raises:
        RuntimeError: if the target spec is not reached within the step limit, or if no candidate
            spec is available to continue the search.

    Example:
        dim_partition_source = {1: [0, 1]}
        dim_partition_target = {0: [0, 1]}
        # DistSpec:
        #     shard_sequence: R,S01,R
        #     device_mesh_shape: (4, 4)
        sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
        # DistSpec:
        #     shard_sequence: S01,R,R
        #     device_mesh_shape: (4, 4)
        sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
        transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target)
        print(f'transform_path: {transform_path}')
        print(f'comm_action_sequence: {comm_action_sequence}')

    output:
        transform_path: [DistSpec:
                shard_sequence: R,S01,R
                device_mesh_shape: (4, 4), DistSpec:
                shard_sequence: R,S0,R
                device_mesh_shape: (4, 4), DistSpec:
                shard_sequence: S0,R,R
                device_mesh_shape: (4, 4), DistSpec:
                shard_sequence: S01,R,R
                device_mesh_shape: (4, 4)]
        comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1),
                CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0),
                CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)]
    """
    MAX_TRANSFORM_STEPS = 20
    total_cost_dict = {"forward": 0, "backward": 0, "total": 0}
    total_steps = 0
    transform_path = []
    comm_action_sequence = []
    spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
    self.cached_spec_pairs_transform_path[spec_pairs] = (None, None)

    # We do nothing if the sharding spec is all the same.
    if source_spec.sharding_sequence_difference(target_spec) == 0:
        self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
        return (transform_path, comm_action_sequence, total_cost_dict)

    temp_sharding_spec = source_spec
    transform_path.append(temp_sharding_spec)

    # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
    while total_steps <= MAX_TRANSFORM_STEPS:
        # Ask for the cost of each single step only. The one-step getters add the origin cost to the
        # returned cost, so passing the running total here would count earlier steps twice.
        zero_cost_dict = dict.fromkeys(total_cost_dict, 0)
        valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, zero_cost_dict)

        best_difference_score = math.inf
        best_candidate = None
        for sharding_spec, (comm_spec, cost_dict) in valid_transform_spec_dict.items():
            spec_difference = sharding_spec.sharding_sequence_difference(target_spec)

            if spec_difference == 0:
                # target reached: record the last step and publish the path in the cache
                for phase in total_cost_dict:
                    total_cost_dict[phase] += cost_dict[phase]
                transform_path.append(sharding_spec)
                comm_action_sequence.append(comm_spec)
                self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
                return (transform_path, comm_action_sequence, total_cost_dict)

            # greedy choice: keep the first candidate with the smallest difference to the target
            if spec_difference < best_difference_score:
                best_candidate = (sharding_spec, comm_spec, cost_dict)
                best_difference_score = spec_difference

        if best_candidate is None:
            # nothing to continue with; fall through to the error below
            break

        temp_sharding_spec, temp_comm_spec, temp_cost_dict = best_candidate
        transform_path.append(temp_sharding_spec)
        comm_action_sequence.append(temp_comm_spec)
        for phase in total_cost_dict:
            total_cost_dict[phase] += temp_cost_dict[phase]
        total_steps += 1

    raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
    """
    Transform a tensor carrying a source sharding spec to target_spec. The transform path is generated
    by the shape_consistency method, and every CommSpec on it is executed in order through
    ``covert_spec_to_action``. The tensor produced by the last action (or the input tensor itself when
    no action is needed) gets ``target_spec`` attached as its ``sharding_spec`` and is returned.

    Argument:
        tensor_with_sharding_spec (torch.Tensor): a tensor whose ``sharding_spec`` attribute holds the source spec.
        target_spec (ShardingSpec): the spec the tensor should be transformed to.

    Return:
        torch.Tensor: the transformed tensor with ``sharding_spec`` set to ``target_spec``.

    Example:
        physical_mesh_id = torch.arange(0, 4)
        mesh_shape = (2, 2)
        # [[0, 1],
        #  [2, 3]]
        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
        entire_shape = torch.Size((4, 2))
        shape_consistency_manager = ShapeConsistencyManager()
        dim_partition_source = {0: [0]}
        dim_partition_target = {1: [0]}

        # DistSpec:
        #     shard_sequence: S0,R
        #     device_mesh_shape: (2, 2)
        sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)

        # DistSpec:
        #     shard_sequence: R,S0
        #     device_mesh_shape: (2, 2)
        sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)

        if rank in (0, 1):
            sharded_tensor_0 = torch.zeros(2, 1)
            sharded_tensor_1 = torch.ones(2, 1)
            # tensor([[0., 1.],
            #         [0., 1.]])
            tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
        if rank in (2, 3):
            sharded_tensor_0 = torch.ones(2, 1) * 2
            sharded_tensor_1 = torch.ones(2, 1) * 3
            # tensor([[2., 3.],
            #         [2., 3.]])
            tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()

        tensor_to_comm.sharding_spec = sharding_spec_source
        tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
        print(tensor_to_comm)

    Output in rank0 and rank2:
        tensor([[0.],
                [0.],
                [2.],
                [2.]])

    Output in rank1 and rank3:
        tensor([[1.],
                [1.],
                [3.],
                [3.]])
    """
    source_spec = tensor_with_sharding_spec.sharding_spec
    _, planned_actions, _ = self.shape_consistency(source_spec, target_spec)
    transformed = tensor_with_sharding_spec
    for action in planned_actions:
        transformed = action.covert_spec_to_action(transformed)
    # the result now follows the target layout, so tag it accordingly
    transformed.sharding_spec = target_spec
    return transformed
def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec):
    """
    Transform ``tensor`` from ``source_spec`` to ``target_spec``.

    Unlike ``apply``, the source spec is passed explicitly instead of being read from the tensor.
    The transform path is generated by the shape_consistency method and every CommSpec on it is
    executed in order through ``covert_spec_to_action``.

    Argument:
        tensor: the tensor currently laid out according to ``source_spec``.
        source_spec: the sharding spec describing the current layout.
        target_spec: the sharding spec describing the requested layout.

    Return:
        the tensor produced by the last action (or ``tensor`` itself when no action is needed), with its
        ``sharding_spec`` attribute set to ``target_spec``.
    """
    _, planned_actions, _ = self.shape_consistency(source_spec, target_spec)
    transformed = tensor
    for action in planned_actions:
        transformed = action.covert_spec_to_action(transformed)
    transformed.sharding_spec = target_spec
    return transformed