[autoparallel] update CommSpec (#1667)

pull/1669/head
YuliangLiu0306 2022-09-29 11:20:59 +08:00 committed by GitHub
parent 247a9dbca9
commit 3f068d1409
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 413 additions and 390 deletions

View File

@ -79,7 +79,7 @@ def generate_resharding_costs(nodes: List[Node],
input_sharding_spec, input_spec)
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost * size_per_elem_bytes
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
except AssertionError as e:
warnings.warn(f'{e}')
resharding_cost = INFINITY_COST

View File

@ -93,7 +93,7 @@ class BcastOpHandler(OperatorHandler):
input_sharding_spec, input_spec)
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost * size_per_elem_bytes
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs

View File

@ -91,18 +91,11 @@ class StrategyGenerator_V2(ABC):
num_ele_in_comm = comm_spec.get_comm_cost()
dtype = operand.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
cost = size_per_elem_bytes * num_ele_in_comm
# compute the fwd
# TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
# it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
# so total cost is either for fwd or bwd.
if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
comm_cost.fwd += cost
elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
comm_cost.fwd += cost
else:
raise ValueError(f"Found unknown CommunicationType {comm_spec.comm_pattern}")
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
comm_cost.fwd += num_ele_in_comm['forward']
comm_cost.bwd += num_ele_in_comm['backward']
comm_cost.total += num_ele_in_comm['total']
# check if communication action exists
# if so, loop over each action and compute the cost of each action
@ -110,9 +103,6 @@ class StrategyGenerator_V2(ABC):
for operand, comm_spec in strategy.communication_actions.items():
_compute_and_add(operand, comm_spec)
# update the total cost
comm_cost.total = comm_cost.fwd + comm_cost.bwd
# update the communication cost attribute in-place
strategy.communication_cost = comm_cost
return strategy

View File

@ -9,10 +9,11 @@ from .colo_parameter import ColoParameter
from .utils import convert_parameter, named_params_with_colotensor
from .dist_spec_mgr import DistSpecManager
from .param_op_hook import ParamOpHook, ParamOpHookManager
from .comm_spec import CollectiveCommPattern, CommSpec
from . import distspec
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
'ReplicaSpec'
'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
]

View File

@ -0,0 +1,358 @@
import torch
from enum import Enum
import torch.distributed as dist
from functools import reduce
import operator
from torch.distributed import ReduceOp
__all__ = [
'CollectiveCommPattern',
'CommSpec',
]
def _all_gather(tensor, comm_spec):
'''
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
]
tensor = tensor
group = process_group
dist.all_gather(tensor_list, tensor, group=group)
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
return output
def _split(tensor, comm_spec):
'''
Implement shard operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, _ in process_groups_list:
if dist.get_rank() in rank_list:
tensor = tensor
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
output = torch.narrow(tensor, dim, start, length)
return output
def _all_to_all(tensor, comm_spec):
'''
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
new_shape = list(tensor.shape)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
new_shape = torch.Size(new_shape)
output_tensor_list = [
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
input_tensor_list = [
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
return output
def _all_reduce(tensor, comm_spec):
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
return tensor
class _ReduceGrad(torch.autograd.Function):
"""
A customized communication operation which forward is an identity operation,
backward is all_reduce operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return input_
@staticmethod
def backward(ctx, grad_output):
return _all_reduce(grad_output, ctx.comm_spec), None
class _ReduceInput(torch.autograd.Function):
"""
A customized communication operation which forward is all_reduce operation,
backward is an identity operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _all_reduce(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
return _all_reduce(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
A customized communication operation which forward is split operation,
backward is an all gather operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return _split(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return _all_gather(grad_output, ctx.comm_spec), None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""
A customized communication operation which forward is an all gather operation,
backward is split operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _all_gather(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return _all_gather(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.comm_spec), None
class _AllToAll(torch.autograd.Function):
"""
A customized communication operation which forward is an all to all operation,
backward is an all to all operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _all_to_all(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
sharding_spec=comm_spec.sharding_spec,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)
ctx.comm_spec = comm_spec_for_backward
return output
@staticmethod
def backward(ctx, grad_outputs):
return _all_to_all(grad_outputs, ctx.comm_spec), None
def reduce_grad(input_, comm_spec):
return _ReduceGrad.apply(input_, comm_spec)
def reduce_input(input_, comm_spec):
return _ReduceInput.apply(input_, comm_spec)
def split_forward_gather_backward(input_, comm_spec):
return _SplitForwardGatherBackward.apply(input_, comm_spec)
def gather_forward_split_backward(input_, comm_spec):
return _GatherForwardSplitBackward.apply(input_, comm_spec)
def all_to_all(input_, comm_spec):
return _AllToAll.apply(input_, comm_spec)
class CollectiveCommPattern(Enum):
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
class CommSpec:
'''
Communication spec is used to record the communication action. It has two main functions:
1. Compute the communication cost which will be used in auto parallel solver.
2. Convert the communication spec to real action which will be used in runtime.
It contains comm_pattern to determine the
communication method, sharding_spec to determine the communication size, gather_dim and shard_dim
to determine the buffer shape, and logical_process_axis
Argument:
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
'''
def __init__(self,
comm_pattern,
sharding_spec,
gather_dim=None,
shard_dim=None,
logical_process_axis=None,
forward_only=False):
self.comm_pattern = comm_pattern
self.sharding_spec = sharding_spec
self.gather_dim = gather_dim
self.shard_dim = shard_dim
self.logical_process_axis = logical_process_axis
self.forward_only = forward_only
if isinstance(self.logical_process_axis, list):
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.logical_process_axis = 0
else:
self.device_mesh = self.sharding_spec.device_mesh
def __repr__(self):
res_list = ["CommSpec:("]
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis: {self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
return ''.join(res_list)
def get_comm_cost(self):
'''
For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
compute the communication cost.
For shard operation, it is an on-chip operation, so the communication cost is zero.
'''
comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
cost_dict = {}
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
# give a tiny cost to shard
backward_communication_cost = 10
if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
# grad should have same shape as input tensor
# all to all operation has same logical process axis as forward.
backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
backward_communication_cost = 0
if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
forward_communication_cost = 0
backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
# give a tiny cost to shard
forward_communication_cost = 10
backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
if self.forward_only:
cost_dict["forward"] = forward_communication_cost
cost_dict["backward"] = 0
cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"]
else:
cost_dict["forward"] = forward_communication_cost
cost_dict["backward"] = backward_communication_cost
cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"]
return cost_dict
def covert_spec_to_action(self, tensor):
'''
Convert CommSpec into runtime action, implement real collection communication to target tensor.
The collection communication action is directed by the CommSpec.
Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
if self.comm_pattern in pattern_to_func_dict:
tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
else:
tensor.data = tensor
pattern_to_func_dict = {
CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward,
CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all,
CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
}

View File

@ -11,355 +11,9 @@ import math
from functools import reduce
import operator
from torch.distributed import ReduceOp
from .comm_spec import *
__all__ = [
'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions',
'set_shape_consistency_options'
]
def _all_gather(tensor, comm_spec):
'''
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
]
tensor = tensor
group = process_group
dist.all_gather(tensor_list, tensor, group=group)
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
return output
def _split(tensor, comm_spec):
'''
Implement shard operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, _ in process_groups_list:
if dist.get_rank() in rank_list:
tensor = tensor
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
output = torch.narrow(tensor, dim, start, length)
return output
def _all_to_all(tensor, comm_spec):
'''
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
new_shape = list(tensor.shape)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
new_shape = torch.Size(new_shape)
output_tensor_list = [
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
input_tensor_list = [
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
return output
def _all_reduce(tensor, comm_spec):
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
return tensor
class _ReduceGrad(torch.autograd.Function):
"""
A customized communication operation which forward is an identity operation,
backward is all_reduce operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return input_
@staticmethod
def backward(ctx, grad_output):
return _all_reduce(grad_output, ctx.comm_spec), None
class _ReduceInput(torch.autograd.Function):
"""
A customized communication operation which forward is all_reduce operation,
backward is an identity operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _all_reduce(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
return _all_reduce(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
A customized communication operation which forward is split operation,
backward is an all gather operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return _split(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return _all_gather(grad_output, ctx.comm_spec), None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""
A customized communication operation which forward is an all gather operation,
backward is split operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _all_gather(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return _all_gather(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.comm_spec), None
class _AllToAll(torch.autograd.Function):
"""
A customized communication operation which forward is an all to all operation,
backward is an all to all operation.
Args:
input_: input matrix.
comm_spec: comm_spec will give information like process group, rank list, etc.
"""
@staticmethod
def symbolic(graph, input_):
return _all_to_all(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
sharding_spec=comm_spec.sharding_spec,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)
ctx.comm_spec = comm_spec_for_backward
return output
@staticmethod
def backward(ctx, grad_outputs):
return _all_to_all(grad_outputs, ctx.comm_spec), None
def reduce_grad(input_, comm_spec):
return _ReduceGrad.apply(input_, comm_spec)
def reduce_input(input_, comm_spec):
return _ReduceInput.apply(input_, comm_spec)
def split_forward_gather_backward(input_, comm_spec):
return _SplitForwardGatherBackward.apply(input_, comm_spec)
def gather_forward_split_backward(input_, comm_spec):
return _GatherForwardSplitBackward.apply(input_, comm_spec)
def all_to_all(input_, comm_spec):
return _AllToAll.apply(input_, comm_spec)
class CollectiveCommPattern(Enum):
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
class CommSpec:
'''
Communication spec is used to record the communication action. It has two main functions:
1. Compute the communication cost which will be used in auto parallel solver.
2. Convert the communication spec to real action which will be used in runtime.
It contains comm_pattern to determine the
communication method, sharding_spec to determine the communication size, gather_dim and shard_dim
to determine the buffer shape, and logical_process_axis
Argument:
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
'''
def __init__(self,
comm_pattern,
sharding_spec,
gather_dim=None,
shard_dim=None,
logical_process_axis=None,
forward_only=False):
self.comm_pattern = comm_pattern
self.sharding_spec = sharding_spec
self.gather_dim = gather_dim
self.shard_dim = shard_dim
self.logical_process_axis = logical_process_axis
self.forward_only = forward_only
if isinstance(self.logical_process_axis, list):
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.logical_process_axis = 0
else:
self.device_mesh = self.sharding_spec.device_mesh
def __repr__(self):
res_list = ["CommSpec:("]
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis: {self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
return ''.join(res_list)
def get_comm_cost(self):
'''
For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
compute the communication cost.
For shard operation, it is an on-chip operation, so the communication cost is zero.
'''
comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
# give a tiny cost to shard
backward_communication_cost = 10
if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
# grad should have same shape as input tensor
# all to all operation has same logical process axis as forward.
backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
backward_communication_cost = 0
if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
forward_communication_cost = 0
backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
# give a tiny cost to shard
forward_communication_cost = 10
backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
try:
if self.forward_only:
total_communication_cost = forward_communication_cost
else:
total_communication_cost = forward_communication_cost + backward_communication_cost
except:
raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")
return total_communication_cost
def covert_spec_to_action(self, tensor):
'''
Convert CommSpec into runtime action, implement real collection communication to target tensor.
The collection communication action is directed by the CommSpec.
Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
if self.comm_pattern in pattern_to_func_dict:
tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
else:
tensor.data = tensor
pattern_to_func_dict = {
CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward,
CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all,
CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
}
__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options']
@dataclass
@ -406,7 +60,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
assert isinstance(value, bool)
self._forward_only = value
def get_all_all_gather_spec(self, source_spec, orig_cost):
def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
'''
Get all valid sharding specs from source_spec with single all-gather operation, and
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@ -463,16 +117,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
forward_only=self.forward_only)
# compute the communication cost with CommSpec
cost = comm_spec.get_comm_cost()
cost_dict = comm_spec.get_comm_cost()
# generate new sharding spec
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
return valid_spec_dict
def get_all_all_to_all_spec(self, source_spec, orig_cost):
def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
'''
Get all valid sharding specs from source_spec with single all-to-all operation, and
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@ -552,7 +208,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
forward_only=self.forward_only)
# compute the communication cost with CommSpec
cost = comm_spec.get_comm_cost()
cost_dict = comm_spec.get_comm_cost()
new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
# We won't add empty list into dim_partition_dict
@ -570,10 +226,12 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
return valid_spec_dict
def get_all_shard_spec(self, source_spec, orig_cost):
def get_all_shard_spec(self, source_spec, orig_cost_dict):
'''
Get all valid sharding specs from source_spec with single shard operation, and
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@ -639,16 +297,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
forward_only=self.forward_only)
# compute the communication cost with CommSpec
cost = comm_spec.get_comm_cost()
cost_dict = comm_spec.get_comm_cost()
# generate new sharding spec
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
return valid_spec_dict
def get_all_one_step_transform_spec(self, source_spec, orig_cost):
def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
'''
Get all valid sharding specs from source_spec with one step transform, and
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
@ -665,9 +325,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
'''
valid_spec_dict = {}
valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost))
valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost))
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost))
valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict))
valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict))
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
return valid_spec_dict
def shape_consistency(self, source_spec, target_spec):
@ -730,7 +390,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
total_cost: 12294.402000000002
'''
MAX_TRANSFORM_STEPS = 20
total_cost = 0
total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0}
total_steps = 0
transform_path = []
comm_action_sequence = []
@ -740,35 +400,37 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# We do nothing if the sharding spec is all the same.
if source_spec.sharding_sequence_difference(target_spec) == 0:
self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
return (transform_path, comm_action_sequence, total_cost)
return (transform_path, comm_action_sequence, total_cost_dict)
temp_sharding_spec = source_spec
transform_path.append(temp_sharding_spec)
# To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
while total_steps <= MAX_TRANSFORM_STEPS:
valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost)
valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost_dict)
best_difference_score = math.inf
for sharding_spec, info_pairs in valid_transform_spec_dict.items():
comm_spec, cost = info_pairs
comm_spec, cost_dict = info_pairs
spec_difference = sharding_spec.sharding_sequence_difference(target_spec)
if spec_difference == 0:
total_cost += cost
for phase, cost in total_cost_dict.items():
total_cost_dict[phase] = cost + cost_dict[phase]
transform_path.append(sharding_spec)
comm_action_sequence.append(comm_spec)
self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
return (transform_path, comm_action_sequence, total_cost)
return (transform_path, comm_action_sequence, total_cost_dict)
if spec_difference < best_difference_score:
temp_sharding_spec = sharding_spec
temp_cost = cost
temp_cost_dict = cost_dict
temp_comm_spec = comm_spec
best_difference_score = spec_difference
transform_path.append(temp_sharding_spec)
comm_action_sequence.append(temp_comm_spec)
total_cost += temp_cost
for phase, cost in total_cost_dict.items():
total_cost_dict[phase] = cost + temp_cost_dict[phase]
total_steps += 1
raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")

View File

@ -27,7 +27,11 @@ def test_one_step_transform():
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec:
# shard_sequence: S0,R,R
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)}
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {
"forward": 0,
"backward": 0,
"total": 0
})
assert '[R, S1, R]' in [
str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
@ -48,7 +52,11 @@ def test_one_step_transform():
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec:
# shard_sequence: S0,R,S1
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)}
rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0)
rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, {
"forward": 0,
"backward": 0,
"total": 0
})
assert '[S01, R, R]' in [
str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
@ -72,7 +80,11 @@ def test_one_step_transform():
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec:
# shard_sequence: S0,R,S1
# device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)}
rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0)
rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, {
"forward": 0,
"backward": 0,
"total": 0
})
assert '[S01, R, R]' in [
str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()