[autoparallel]add essential CommActions for broadcast oprands (#1793)

pull/1792/head
YuliangLiu0306 2022-11-04 18:36:42 +08:00 committed by GitHub
parent 05ce3d369f
commit e34e850a4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 102 additions and 24 deletions

View File

@ -3,10 +3,17 @@ from typing import Dict, List, Union
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
ShardingStrategy,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..constants import BCAST_FUNC_OP
from ..utils import recover_sharding_spec_for_broadcast_shape
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
@ -81,6 +88,15 @@ class BinaryElementwiseHandler(NodeHandler):
physical_shape = op_data.data.shape
logical_shape = op_data.logical_shape
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
sharding_spec = recover_sharding_spec_for_broadcast_shape(sharding_spec, logical_shape, physical_shape)
sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
sharding_spec, logical_shape, physical_shape)
strategy.sharding_specs[op_data] = sharding_spec
if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node,
removed_dims=removed_dims,
op_data=op_data,
sharding_spec=sharding_spec)
strategy.communication_actions[op_data] = comm_action
return strategy

View File

@ -2,8 +2,10 @@ from typing import Dict, List, Union
import torch
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler):
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
bias_physical_shape)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
bias_sharding_spec, bias_logical_shape, bias_physical_shape)
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node,
removed_dims=removed_dims,
op_data=bias_op_data,
sharding_spec=bias_sharding_spec)
strategy.communication_actions[bias_op_data] = comm_action
return strategy

View File

@ -213,7 +213,7 @@ class Broadcaster(BmmTransform):
tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]
physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape,
physical_shape=tensor_shape_before_broadcast)

View File

@ -4,7 +4,7 @@ from typing import Dict, List
import torch
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector)
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
@ -81,8 +81,8 @@ class WhereHandler(NodeHandler):
logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
logical_shape = logical_op_data_mapping[key].logical_shape
physical_shape = physical_op_data_mapping[key].logical_shape
physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec, logical_shape,
physical_shape)
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec, logical_shape, physical_shape)
strategy.sharding_specs.pop(logical_op_data_mapping[key])
strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"

View File

@ -1,4 +1,10 @@
from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape
from .broadcast import (
BroadcastType,
comm_actions_for_oprands,
get_broadcast_shape,
is_broadcastable,
recover_sharding_spec_for_broadcast_shape,
)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import check_sharding_spec_validity, ignore_sharding_exception
from .sharding import (
@ -13,5 +19,5 @@ __all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands'
]

View File

@ -2,10 +2,21 @@ from enum import Enum, auto
from typing import List
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
)
from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
__all__ = [
'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
'comm_actions_for_oprands'
]
class BroadcastType(Enum):
@ -86,8 +97,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec
# recording the sharding dimensions removed during logical shape converting to physical one
removed_dims = []
if list(logical_shape) == list(physical_shape):
return logical_sharding_spec
return logical_sharding_spec, removed_dims
# get the number of dimensions
logical_num_dims = len(logical_shape)
@ -104,7 +118,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
logical_broadcast_type = logical_dim_broadcast_info[shape_dim]
if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
pass
removed_dims.extend(mesh_dim)
else:
# get the corresponding physical dim
physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
@ -114,4 +128,33 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
entire_shape=physical_shape,
dim_partition_dict=physical_dim_partition)
return physical_sharding_spec
return physical_sharding_spec, removed_dims
def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
sharding_spec: ShardingSpec) -> CommAction:
"""
This method is used to generate communication actions for oprands which lose information
during convert logical shape to physical shape.
"""
if len(removed_dims) == 1:
# if list length is 1, extract element from list to avoid using flatten device mesh
removed_dims = removed_dims[0]
comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
sharding_spec=sharding_spec,
logical_process_axis=removed_dims)
if op_data.type == OperationDataType.PARAM:
comm_type = CommType.HOOK
else:
comm_type = CommType.BEFORE
arg_index = -1
for index, arg in enumerate(node.args):
if op_data.name == str(arg):
arg_index = index
assert arg_index >= 0, f'op_data should be an argument of node.'
comm_action = CommAction(
comm_spec=comm_spec,
comm_type=comm_type,
arg_index=arg_index,
)
return comm_action

View File

@ -39,8 +39,8 @@ class BiasAdditionConv(BiasAdditionModule):
This method is used to reshape the bias node in order to make bias and
output of non-bias convolution broadcastable.
"""
bias_shape = [1] * dimensions
bias_shape[1] = -1
bias_shape = [1] * (dimensions - 1)
bias_shape[0] = -1
bias_reshape_node_kind = 'call_method'
bias_reshape_node_target = 'view'
bias_reshape_node_args = (self.bias_proxy, bias_shape)

View File

@ -1,7 +1,10 @@
import torch
from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
recover_sharding_spec_for_broadcast_shape)
from colossalai.auto_parallel.tensor_shard.utils import (
get_broadcast_shape,
is_broadcastable,
recover_sharding_spec_for_broadcast_shape,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
@ -51,8 +54,8 @@ def test_recover_sharding_spec_for_broadcast_shape():
1: [1]
},
entire_shape=broadcast_shape)
physical_sharding_spec_for_x1 = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec_for_x1,
broadcast_shape, x1.shape)
physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec_for_x1, broadcast_shape, x1.shape)
print(physical_sharding_spec_for_x1)
assert physical_sharding_spec_for_x1.entire_shape == x1.shape

View File

@ -105,7 +105,7 @@ def test_conv_module():
assert weight_node._meta_data.shape == (6, 3, 2, 2)
assert bias_node._meta_data.shape == (6,)
assert conv_node._meta_data.shape == (4, 6, 63, 63)
assert view_node._meta_data.shape == (1, 6, 1, 1)
assert view_node._meta_data.shape == (6, 1, 1)
assert add_node._meta_data.shape == (4, 6, 63, 63)