mirror of https://github.com/hpcaitech/ColossalAI
[tensor] use communication autograd func (#1617)
* [tensor] use communication autograd func
* change all to all comm spec info
* rename pattern and distinguish fwd/bwd
* polish code

parent c7ac0f4ab2
commit 702dbc5288
@@ -8,6 +8,7 @@ import warnings
from functools import reduce
import functools
import operator
from .constants import INFINITY_COST


def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@@ -68,19 +69,16 @@ def generate_resharding_costs(nodes: List[Node],
        for strategy in input_node.strategies_vector:
            input_sharding_spec = strategy.output_sharding_spec
            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
            # compute the resharding cost during forward phase
            _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec)
            try:
                # compute the resharding cost
                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)

            if count_backward:
                # In backward phase, we should convert grad with target_spec into input_sharding_spec
                _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
                    input_spec, input_sharding_spec)
                total_resharding_cost = resharding_cost_forward + resharding_cost_backward
            else:
                total_resharding_cost = resharding_cost_forward

            # we need multiply the size of elem dtype to get correct communication cost
            resharding_cost = total_resharding_cost * size_per_elem_bytes
                # we need multiply the size of elem dtype to get correct communication cost
                resharding_cost = total_resharding_cost * size_per_elem_bytes
            except AssertionError as e:
                warnings.warn(f'{e}')
                resharding_cost = INFINITY_COST
            resharding_costs[input_node].append(resharding_cost)
    return resharding_costs
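The hunk above replaces the explicit forward/backward branching with a single shape_consistency call guarded by try/except. A minimal sketch of the resulting cost path, assuming a manager object with the same shape_consistency interface; `manager`, `src_spec` and `dst_spec` are placeholder names, not part of the diff:

# Sketch only: placeholders stand in for shape_consistency_manager and the ShardingSpec objects above.
INFINITY_COST = float('inf')    # stand-in for the INFINITY_COST constant imported above

def resharding_cost(manager, src_spec, dst_spec, size_per_elem_bytes=4):
    try:
        # shape_consistency returns a 3-tuple whose last element is the total
        # (forward + backward) communication cost counted in tensor elements
        _, _, total_cost = manager.shape_consistency(src_spec, dst_spec)
        # scale by the element size to express the cost in bytes
        return total_cost * size_per_elem_bytes
    except AssertionError:
        # no legal transform path: make this strategy prohibitively expensive
        return INFINITY_COST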
@@ -4,7 +4,7 @@ import operator
__all__ = [
    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
    'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP'
    'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
]

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@@ -18,11 +18,225 @@ __all__ = [
]


def _all_gather(tensor, comm_spec):
    '''
    Implement all gather operation on device mesh based on information provided by comm_spec.
    '''
    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, process_group in process_groups_list:
        if dist.get_rank() in rank_list:
            tensor_list = [
                torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
                for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
            ]
            tensor = tensor
            group = process_group
            dist.all_gather(tensor_list, tensor, group=group)
            output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
            return output


def _split(tensor, comm_spec):
    '''
    Implement shard operation on device mesh based on information provided by comm_spec.
    '''
    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, _ in process_groups_list:
        if dist.get_rank() in rank_list:
            tensor = tensor
            dim = comm_spec.shard_dim
            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
            start = length * rank_list.index(dist.get_rank())
            output = torch.narrow(tensor, dim, start, length)
            return output


def _all_to_all(tensor, comm_spec):
    '''
    Implement all to all operation on device mesh based on information provided by comm_spec.
    '''
    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, process_group in process_groups_list:
        if dist.get_rank() in rank_list:
            new_shape = list(tensor.shape)
            new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
            new_shape = torch.Size(new_shape)
            output_tensor_list = [
                torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
            ]
            dim = comm_spec.shard_dim
            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
            input_tensor_list = [
                torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
            ]
            group = process_group
            dist.all_to_all(output_tensor_list, input_tensor_list, group)
            output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
            return output


def _all_reduce(tensor, comm_spec):
    '''
    Implement all reduce operation on device mesh based on information provided by comm_spec.
    '''
    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, process_group in process_groups_list:
        if dist.get_rank() in rank_list:
            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
            return tensor

class _ReduceGrad(torch.autograd.Function):
    """
    A customized communication operation which forward is an identity operation,
    backward is all_reduce operation.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_, comm_spec):
        ctx.comm_spec = comm_spec
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _all_reduce(grad_output, ctx.comm_spec), None


class _ReduceInput(torch.autograd.Function):
    """
    A customized communication operation which forward is all_reduce operation,
    backward is an identity operation.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _all_reduce(input_)

    @staticmethod
    def forward(ctx, input_, comm_spec):
        return _all_reduce(input_, comm_spec)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    A customized communication operation which forward is split operation,
    backward is an all gather operation.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_, comm_spec):
        ctx.comm_spec = comm_spec
        return _split(input_, comm_spec)

    @staticmethod
    def backward(ctx, grad_output):
        return _all_gather(grad_output, ctx.comm_spec), None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """
    A customized communication operation which forward is an all gather operation,
    backward is split operation.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _all_gather(input_)

    @staticmethod
    def forward(ctx, input_, comm_spec):
        ctx.comm_spec = comm_spec
        return _all_gather(input_, comm_spec)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.comm_spec), None


class _AllToAll(torch.autograd.Function):
    """
    A customized communication operation which forward is an all to all operation,
    backward is an all to all operation.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _all_to_all(input_)

    @staticmethod
    def forward(ctx, input_, comm_spec):
        output = _all_to_all(input_, comm_spec)
        comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
                                          sharding_spec=comm_spec.sharding_spec,
                                          gather_dim=comm_spec.shard_dim,
                                          shard_dim=comm_spec.gather_dim,
                                          logical_process_axis=comm_spec.logical_process_axis)
        ctx.comm_spec = comm_spec_for_backward
        return output

    @staticmethod
    def backward(ctx, grad_outputs):
        return _all_to_all(grad_outputs, ctx.comm_spec), None

def reduce_grad(input_, comm_spec):
    return _ReduceGrad.apply(input_, comm_spec)


def reduce_input(input_, comm_spec):
    return _ReduceInput.apply(input_, comm_spec)


def split_forward_gather_backward(input_, comm_spec):
    return _SplitForwardGatherBackward.apply(input_, comm_spec)


def gather_forward_split_backward(input_, comm_spec):
    return _GatherForwardSplitBackward.apply(input_, comm_spec)


def all_to_all(input_, comm_spec):
    return _AllToAll.apply(input_, comm_spec)
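These thin wrappers are the public entry points to the autograd functions above. A minimal usage sketch, assuming a process group is already initialised and that `sharding_spec` and `x` exist on each rank (both are assumptions, not part of the diff):

# Sketch only: all-gather along tensor dim 0 over mesh axis 1 in forward,
# split back along the same dim in backward.
spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                sharding_spec,              # assumed to exist on this rank
                gather_dim=0,
                shard_dim=0,                # reused by the backward split
                logical_process_axis=1)
out = gather_forward_split_backward(x, spec)   # fwd: all-gather
out.sum().backward()                           # bwd: gradients are split back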
class CollectiveCommPattern(Enum):
    ALLGATHER = 'all_gather'
    ALLTOALL = 'all_to_all'
    SHARD = 'shard'
    ALLREDUCE = 'all_reduce'
    GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
    ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
    SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
    REDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
    IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'


class CommSpec:
@@ -42,12 +256,19 @@ class CommSpec:
        logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
    '''

    def __init__(self, comm_pattern, sharding_spec, gather_dim=None, shard_dim=None, logical_process_axis=None):
    def __init__(self,
                 comm_pattern,
                 sharding_spec,
                 gather_dim=None,
                 shard_dim=None,
                 logical_process_axis=None,
                 forward_only=False):
        self.comm_pattern = comm_pattern
        self.sharding_spec = sharding_spec
        self.gather_dim = gather_dim
        self.shard_dim = shard_dim
        self.logical_process_axis = logical_process_axis
        self.forward_only = forward_only
        if isinstance(self.logical_process_axis, list):
            self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
            self.logical_process_axis = 0
@@ -56,21 +277,24 @@ class CommSpec:

    def __repr__(self):
        res_list = ["CommSpec:("]
        if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
            res_list.append(f"comm_pattern:all_gather, ")
        if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
            res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
            res_list.append(f"gather_dim:{self.gather_dim}, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
        elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
            res_list.append(f"comm_pattern:all2all, ")
        elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
            res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
            res_list.append(f"gather_dim:{self.gather_dim}, ")
            res_list.append(f"shard_dim:{self.shard_dim}, ")
            res_list.append(f"logical_process_axis: {self.logical_process_axis})")
        elif self.comm_pattern == CollectiveCommPattern.SHARD:
            res_list.append(f"comm_pattern:shard, ")
        elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
            res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
            res_list.append(f"shard_dim:{self.shard_dim}, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
        elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
            res_list.append(f"comm_pattern:all_reduce, ")
        elif self.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
            res_list.append(f"comm_pattern:REDUCE_FWD_IDENTITY_BWD, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
        elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
            res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")

        return ''.join(res_list)
@@ -82,16 +306,38 @@ class CommSpec:
        For shard operation, it is an on-chip operation, so the communication cost is zero.
        '''
        comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
        if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
            return self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
        if self.comm_pattern == CollectiveCommPattern.ALLTOALL:
            return self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
        if self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
            return self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
        if self.comm_pattern == CollectiveCommPattern.SHARD:
        if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
            forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
            # give a tiny cost to shard
            return 10
        raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")
            backward_communication_cost = 10

        if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
            forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
            # grad should have same shape as input tensor
            # all to all operation has same logical process axis as forward.
            backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)

        if self.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
            forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
            backward_communication_cost = 0

        if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
            forward_communication_cost = 0
            backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)

        if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
            # give a tiny cost to shard
            forward_communication_cost = 10
            backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
        try:
            if self.forward_only:
                total_communication_cost = forward_communication_cost
            else:
                total_communication_cost = forward_communication_cost + backward_communication_cost
        except:
            raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")

        return total_communication_cost

    def covert_spec_to_action(self, tensor):
        '''
@@ -101,64 +347,21 @@ class CommSpec:
        Argument:
            tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
        '''
        process_groups_list = self.device_mesh.process_groups_dict[self.logical_process_axis]

        if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
            for rank_list, process_group in process_groups_list:
                if dist.get_rank() in rank_list:
                    tensor_list = [
                        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
                        for _ in range(self.device_mesh.mesh_shape[self.logical_process_axis])
                    ]
                    tensor = tensor
                    group = process_group
                    dist.all_gather(tensor_list, tensor, group=group)
                    tensor.data = torch.cat(tuple(tensor_list), self.gather_dim)

        elif self.comm_pattern == CollectiveCommPattern.SHARD:
            for rank_list, process_group in process_groups_list:
                if dist.get_rank() in rank_list:
                    tensor = tensor
                    dim = self.shard_dim
                    length = tensor.shape[self.shard_dim] // len(rank_list)
                    start = length * rank_list.index(dist.get_rank())
                    tensor.data = torch.narrow(tensor, dim, start, length)

        elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
            for rank_list, process_group in process_groups_list:
                if dist.get_rank() in rank_list:
                    new_shape = list(tensor.shape)
                    new_shape[self.shard_dim] = new_shape[self.shard_dim] // len(rank_list)
                    new_shape = torch.Size(new_shape)
                    output_tensor_list = [
                        torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
                    ]
                    dim = self.shard_dim
                    length = tensor.shape[self.shard_dim] // len(rank_list)
                    input_tensor_list = [
                        torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
                    ]
                    group = process_group
                    dist.all_to_all(output_tensor_list, input_tensor_list, group)
                    tensor.data = torch.cat(tuple(output_tensor_list), self.gather_dim)

        elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
            # For the consistency of collective communication operation, we temporally do not
            # allow all_reduce two different mesh dimensions in the same time.
            # e.g.: MatMul[(R, S01), (S01, R)] -> Partial(R, R),
            # all_reduce(Partial, logical_pg=(0, 1)) is NOT allowed, instead
            # we need to do this in two steps:
            #   1. all_reduce(Partial, logical_pg=1)
            #   2. all_reduce(Partial, logical_pg=0)
            for rank_list, process_group in process_groups_list:
                if dist.get_rank() in rank_list:
                    dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
                    tensor.data = tensor

        if self.comm_pattern in pattern_to_func_dict:
            tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
        else:
            tensor.data = tensor


pattern_to_func_dict = {
    CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward,
    CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all,
    CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
    CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD: reduce_input,
    CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
}
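With this table, covert_spec_to_action reduces to a dictionary lookup instead of the per-pattern if/elif chain removed above. A sketch of the dispatch as a functional variant that returns the result rather than assigning tensor.data (helper name is illustrative only):

# Sketch only: unknown patterns fall back to the identity.
def apply_comm_spec(tensor, comm_spec):
    handler = pattern_to_func_dict.get(comm_spec.comm_pattern)
    if handler is None:
        return tensor
    return handler(tensor, comm_spec)   # e.g. all_to_all, reduce_grad, ...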
@dataclass
class ShapeConsistencyOptions:
    """
@@ -180,6 +383,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):

    def __init__(self):
        self._options = None
        self._forward_only = False
        self.total_communication_cost = 0
        self.total_transform_steps = 0
        self.cached_spec_pairs_transform_path = {}
@@ -193,6 +397,15 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        assert isinstance(options_, ShapeConsistencyOptions)
        self._options = options_

    @property
    def forward_only(self):
        return self._forward_only

    @forward_only.setter
    def forward_only(self, value):
        assert isinstance(value, bool)
        self._forward_only = value

    def get_all_all_gather_spec(self, source_spec, orig_cost):
        '''
        Get all valid sharding specs from source_spec with single all-gather operation, and
@@ -224,7 +437,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            device_mesh_shape: (4, 4): 0}
        '''
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.ALLGATHER
        comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
        for target_pair in source_spec.dim_partition_dict.items():
            shard_list = all_gather_simulator(target_pair)
            index = target_pair[0]
@@ -240,10 +453,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
            gather_dim = index
            logical_process_axis = target_pair[1][-1]
            comm_spec = CommSpec(comm_pattern,
                                 sharding_spec=source_spec,
                                 gather_dim=gather_dim,
                                 logical_process_axis=logical_process_axis)
            comm_spec = CommSpec(
                comm_pattern,
                sharding_spec=source_spec,
                gather_dim=gather_dim,
                # shard_dim will be used during backward
                shard_dim=gather_dim,
                logical_process_axis=logical_process_axis,
                forward_only=self.forward_only)

            # compute the communication cost with CommSpec
            cost = comm_spec.get_comm_cost()
@@ -288,7 +505,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            device_mesh_shape: (4, 4): 0}
        '''
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.ALLTOALL
        comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
        tensor_dims = len(source_spec.entire_shape)
        for f_index in range(tensor_dims - 1):
            for b_index in range(f_index + 1, tensor_dims):
@@ -331,7 +548,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                                     sharding_spec=source_spec,
                                     gather_dim=gather_dim,
                                     shard_dim=shard_dim,
                                     logical_process_axis=logical_process_axis)
                                     logical_process_axis=logical_process_axis,
                                     forward_only=self.forward_only)

                # compute the communication cost with CommSpec
                cost = comm_spec.get_comm_cost()
@@ -388,7 +606,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            device_mesh_shape: (4, 4): 0}
        '''
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.SHARD
        comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD

        # legal sharding dims means the mesh_id is still available to use.
        legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
@@ -415,8 +633,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
            logical_process_axis = shard_list[-1]
            comm_spec = CommSpec(comm_pattern,
                                 sharding_spec=source_spec,
                                 gather_dim=shard_dim,
                                 shard_dim=shard_dim,
                                 logical_process_axis=logical_process_axis)
                                 logical_process_axis=logical_process_axis,
                                 forward_only=self.forward_only)

            # compute the communication cost with CommSpec
            cost = comm_spec.get_comm_cost()
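Every spec-generation method now threads the manager's forward_only flag into the CommSpec it builds, so get_comm_cost() can drop the backward communication term. A usage sketch, assuming `source_spec` is an existing ShardingSpec (an assumption, not part of the diff):

# Sketch only: restrict cost accounting to the forward pass.
manager = ShapeConsistencyManager()
manager.forward_only = True     # setter asserts the value is a bool
fwd_only_specs = manager.get_all_all_gather_spec(source_spec, orig_cost=0)
manager.forward_only = False    # back to forward + backward accounting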
@@ -33,7 +33,10 @@ def check_all_gather(device_mesh, rank):
    sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.ALLGATHER, sharding_spec, gather_dim=1, logical_process_axis=1)
    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                         sharding_spec,
                         gather_dim=1,
                         logical_process_axis=1)
    comm_spec.covert_spec_to_action(sharded_tensor_to_comm)

    assert sharded_tensor_to_comm.equal(tensor_to_check)
@@ -56,7 +59,7 @@ def check_shard(device_mesh, rank):
    sharding_spec = ShardingSpec(device_mesh, tensor_to_shard.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.SHARD, sharding_spec, shard_dim=1, logical_process_axis=1)
    comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
    comm_spec.covert_spec_to_action(tensor_to_shard)

    if rank in (0, 2):
@@ -102,7 +105,7 @@ def check_all_to_all(device_mesh, rank):
    sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.ALLTOALL,
    comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
                         sharding_spec,
                         gather_dim=0,
                         shard_dim=1,
@@ -112,7 +115,7 @@ def check_all_to_all(device_mesh, rank):
    assert tensor_to_comm.equal(tensor_to_check)


def check_all_reduce(device_mesh, rank):
def check_all_reduce_fwd(device_mesh, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank
@@ -133,8 +136,25 @@ def check_all_reduce(device_mesh, rank):
    # device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:0)
    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=0)
    comm_spec = CommSpec(CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
    comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


def check_all_reduce_bwd(device_mesh, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    tensor_to_check = torch.ones(2, 2).cuda() * rank

    dim_partition_dict = {}
    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
    comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)
@@ -157,7 +177,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=[0, 1])
    comm_spec = CommSpec(CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
    comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)
@@ -184,7 +204,8 @@ def check_comm(rank, world_size, port):
    check_all_to_all(device_mesh, rank)

    # test all reduce
    check_all_reduce(device_mesh, rank)
    check_all_reduce_fwd(device_mesh, rank)
    check_all_reduce_bwd(device_mesh, rank)

    # test all reduce in 1D flatten device mesh
    check_all_reduce_in_flatten_device_mesh(device_mesh, rank)
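The new *_fwd/*_bwd checks split the renamed patterns by direction. A sketch of what the backward direction exercises, assuming the distributed group and `sharding_spec` are already set up as in the tests above:

# Sketch only: IDENTITY_FWD_ALLREDUCE_BWD leaves the tensor untouched in
# forward and all-reduces the gradient in backward (via reduce_grad/_ReduceGrad).
x = torch.ones(2, 2, device='cuda', requires_grad=True)
spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
y = reduce_grad(x, spec)     # forward: identity, so y equals x
y.sum().backward()           # backward: x.grad is all-reduced over mesh axis 0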
@@ -106,18 +106,18 @@ def test_shape_consistency():
    assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'

    # all-gather(S01) -> S0
    assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.ALLGATHER
    assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
    assert comm_action_sequence[0].gather_dim == 1
    assert comm_action_sequence[0].logical_process_axis == 1

    # all-to-all(R, S0) -> [S0, R]
    assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALLTOALL
    assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
    assert comm_action_sequence[1].gather_dim == 1
    assert comm_action_sequence[1].shard_dim == 0
    assert comm_action_sequence[1].logical_process_axis == 0

    # shard(S0) -> [S01]
    assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SHARD
    assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
    assert comm_action_sequence[2].shard_dim == 0
    assert comm_action_sequence[2].logical_process_axis == 1