[autoparallel] add essential CommActions for broadcast operands (#1793)

pull/1792/head
YuliangLiu0306 2022-11-04 18:36:42 +08:00 committed by GitHub
parent 05ce3d369f
commit e34e850a4c
9 changed files with 102 additions and 24 deletions

View File

@@ -3,10 +3,17 @@ from typing import Dict, List, Union
 import torch
 from torch.fx.node import Node

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+)
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager

 from ..constants import BCAST_FUNC_OP
-from ..utils import recover_sharding_spec_for_broadcast_shape
+from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
 from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator

@@ -81,6 +88,15 @@ class BinaryElementwiseHandler(NodeHandler):
                 physical_shape = op_data.data.shape
                 logical_shape = op_data.logical_shape
                 sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
-                sharding_spec = recover_sharding_spec_for_broadcast_shape(sharding_spec, logical_shape, physical_shape)
+                sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+                    sharding_spec, logical_shape, physical_shape)
                 strategy.sharding_specs[op_data] = sharding_spec
+                if len(removed_dims) > 0:
+                    comm_action = comm_actions_for_oprands(node=self.node,
+                                                           removed_dims=removed_dims,
+                                                           op_data=op_data,
+                                                           sharding_spec=sharding_spec)
+                    strategy.communication_actions[op_data] = comm_action
+
         return strategy
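
When a broadcast dimension of an operand was sharded in the logical (broadcast) shape, that sharding cannot be expressed on the smaller physical tensor, and the operand's gradient ends up as a partial sum on each rank; this is why the handler now attaches an IDENTITY_FWD_ALLREDUCE_BWD communication action. The single-process sketch below is not part of the diff and uses hypothetical shapes; it only illustrates the autograd behaviour that motivates the backward all-reduce.

    import torch

    # x2 of shape (4,) is broadcast against x1 of shape (2, 4); autograd reduces
    # the gradient of x2 over the broadcast dimension. If that dimension were
    # sharded across a mesh axis, each rank would only hold a partial sum, which
    # is what the new CommAction repairs with an all-reduce in the backward pass.
    x1 = torch.randn(2, 4, requires_grad=True)
    x2 = torch.randn(4, requires_grad=True)
    (x1 + x2).sum().backward()
    assert x2.grad.shape == (4,)
    assert torch.allclose(x2.grad, torch.full((4,), 2.0))   # each entry sums two ones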

View File

@@ -2,8 +2,10 @@ from typing import Dict, List, Union

 import torch

-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from ..utils import recover_sharding_spec_for_broadcast_shape
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
+
+from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
 from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator

@@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler):
             bias_physical_shape = bias_op_data.data.shape
             bias_logical_shape = bias_op_data.logical_shape
             bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
-            bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
-                                                                           bias_physical_shape)
+            bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+                bias_sharding_spec, bias_logical_shape, bias_physical_shape)
             strategy.sharding_specs[bias_op_data] = bias_sharding_spec
+            if len(removed_dims) > 0:
+                comm_action = comm_actions_for_oprands(node=self.node,
+                                                       removed_dims=removed_dims,
+                                                       op_data=bias_op_data,
+                                                       sharding_spec=bias_sharding_spec)
+                strategy.communication_actions[bias_op_data] = comm_action
+
         return strategy
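
The same bookkeeping is applied to the bias operand of addbmm, whose logical shape is the broadcast shape of the batched matmul output while its physical shape may be smaller. A small single-process example with hypothetical sizes (not part of the diff) shows why the bias is exactly the operand that can lose sharding information:

    import torch

    # torch.addbmm broadcasts the bias against the (n, p) matmul result, so the
    # bias's logical shape (2, 5) differs from its physical shape (5,). Its
    # gradient is reduced over the broadcast dimension, which is the information
    # loss the new comm_actions_for_oprands call accounts for.
    bias = torch.randn(5, requires_grad=True)
    batch1 = torch.randn(3, 2, 4)
    batch2 = torch.randn(3, 4, 5)
    out = torch.addbmm(bias, batch1, batch2)   # shape (2, 5)
    out.sum().backward()
    assert out.shape == (2, 5)
    assert bias.grad.shape == (5,)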

View File

@@ -213,7 +213,7 @@ class Broadcaster(BmmTransform):
                 tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]
-                physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
+                physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
                     logical_sharding_spec=sharding_spec,
                     logical_shape=sharding_spec.entire_shape,
                     physical_shape=tensor_shape_before_broadcast)

View File

@@ -4,7 +4,7 @@ from typing import Dict, List

 import torch

-from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector)
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
 from ..utils import recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry

@@ -81,8 +81,8 @@ class WhereHandler(NodeHandler):
             logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
             logical_shape = logical_op_data_mapping[key].logical_shape
             physical_shape = physical_op_data_mapping[key].logical_shape
-            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec, logical_shape,
-                                                                               physical_shape)
+            physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+                logical_sharding_spec, logical_shape, physical_shape)
             strategy.sharding_specs.pop(logical_op_data_mapping[key])
             strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
             strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"

View File

@@ -1,4 +1,10 @@
-from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape
+from .broadcast import (
+    BroadcastType,
+    comm_actions_for_oprands,
+    get_broadcast_shape,
+    is_broadcastable,
+    recover_sharding_spec_for_broadcast_shape,
+)
 from .factory import generate_resharding_costs, generate_sharding_spec
 from .misc import check_sharding_spec_validity, ignore_sharding_exception
 from .sharding import (

@@ -13,5 +19,5 @@ __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
     'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
-    'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
+    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands'
 ]

View File

@@ -2,10 +2,21 @@ from enum import Enum, auto
 from typing import List

 import torch
+from torch.fx.node import Node

+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+)
+from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec

-__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
+__all__ = [
+    'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
+    'comm_actions_for_oprands'
+]


 class BroadcastType(Enum):

@@ -86,8 +97,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
     """
     # if the two shapes are the same, no broadcast occurs
     # we directly return the current sharding spec
+    # recording the sharding dimensions removed during logical shape converting to physical one
+    removed_dims = []
     if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
+        return logical_sharding_spec, removed_dims

     # get the number of dimensions
     logical_num_dims = len(logical_shape)

@@ -104,7 +118,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
         logical_broadcast_type = logical_dim_broadcast_info[shape_dim]

         if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
-            pass
+            removed_dims.extend(mesh_dim)
         else:
             # get the corresponding physical dim
             physical_dim = physical_num_dims - (logical_num_dims - shape_dim)

@@ -114,4 +128,33 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
                                               entire_shape=physical_shape,
                                               dim_partition_dict=physical_dim_partition)

-    return physical_sharding_spec
+    return physical_sharding_spec, removed_dims
+
+
+def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
+                             sharding_spec: ShardingSpec) -> CommAction:
+    """
+    This method is used to generate communication actions for oprands which lose information
+    during convert logical shape to physical shape.
+    """
+    if len(removed_dims) == 1:
+        # if list length is 1, extract element from list to avoid using flatten device mesh
+        removed_dims = removed_dims[0]
+    comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                         sharding_spec=sharding_spec,
+                         logical_process_axis=removed_dims)
+    if op_data.type == OperationDataType.PARAM:
+        comm_type = CommType.HOOK
+    else:
+        comm_type = CommType.BEFORE
+    arg_index = -1
+    for index, arg in enumerate(node.args):
+        if op_data.name == str(arg):
+            arg_index = index
+    assert arg_index >= 0, f'op_data should be an argument of node.'
+    comm_action = CommAction(
+        comm_spec=comm_spec,
+        comm_type=comm_type,
+        arg_index=arg_index,
+    )
+    return comm_action
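
For reference, a sketch of the new two-value return under an assumed 2x2 device mesh: an operand of physical shape (4, 8) is broadcast to the logical shape (2, 4, 8), and its padded leading dimension is sharded along mesh axis 0. The DeviceMesh/ShardingSpec construction mirrors the existing unit test in this commit; the concrete shapes and partitions are illustrative assumptions, not part of the diff.

    import torch

    from colossalai.auto_parallel.tensor_shard.utils import recover_sharding_spec_for_broadcast_shape
    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.sharding_spec import ShardingSpec

    physical_mesh_id = torch.arange(0, 4)
    device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

    # logical (broadcast) view: dim 0 is a padded broadcast dim sharded on mesh axis 0
    logical_spec = ShardingSpec(device_mesh=device_mesh,
                                entire_shape=torch.Size((2, 4, 8)),
                                dim_partition_dict={0: [0], 1: [1]})

    physical_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
        logical_spec, torch.Size((2, 4, 8)), torch.Size((4, 8)))

    # The sharding on the padded dim cannot survive on the physical tensor, so mesh
    # axis 0 is reported back; comm_actions_for_oprands turns it into an
    # IDENTITY_FWD_ALLREDUCE_BWD spec (no-op forward, gradient all-reduce backward),
    # delivered as a HOOK for parameters or scheduled BEFORE the op for other args.
    assert removed_dims == [0]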

View File

@@ -39,8 +39,8 @@ class BiasAdditionConv(BiasAdditionModule):
         This method is used to reshape the bias node in order to make bias and
         output of non-bias convolution broadcastable.
         """
-        bias_shape = [1] * dimensions
-        bias_shape[1] = -1
+        bias_shape = [1] * (dimensions - 1)
+        bias_shape[0] = -1
         bias_reshape_node_kind = 'call_method'
         bias_reshape_node_target = 'view'
         bias_reshape_node_args = (self.bias_proxy, bias_shape)
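
The bias view no longer carries a leading batch placeholder: for a conv output of rank `dimensions` (e.g. 4 for NCHW), the bias is viewed to (-1, 1, ..., 1) with `dimensions - 1` entries, which still broadcasts against the output and matches the updated meta-data test further below. A quick standalone check with hypothetical sizes (out_channels 6, output (4, 6, 63, 63)), not part of the diff:

    import torch

    dimensions = 4                          # rank of the conv output, e.g. NCHW
    bias = torch.randn(6)
    bias_shape = [1] * (dimensions - 1)     # [1, 1, 1]
    bias_shape[0] = -1                      # [-1, 1, 1]
    reshaped_bias = bias.view(bias_shape)   # shape (6, 1, 1) instead of (1, 6, 1, 1)
    output = torch.randn(4, 6, 63, 63) + reshaped_bias   # broadcasts over N, H, W
    assert reshaped_bias.shape == (6, 1, 1)
    assert output.shape == (4, 6, 63, 63)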

View File

@@ -1,7 +1,10 @@
 import torch

-from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
-                                                          recover_sharding_spec_for_broadcast_shape)
+from colossalai.auto_parallel.tensor_shard.utils import (
+    get_broadcast_shape,
+    is_broadcastable,
+    recover_sharding_spec_for_broadcast_shape,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec

@@ -51,8 +54,8 @@ def test_recover_sharding_spec_for_broadcast_shape():
                                                       1: [1]
                                                   },
                                                   entire_shape=broadcast_shape)
-    physical_sharding_spec_for_x1 = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec_for_x1,
-                                                                              broadcast_shape, x1.shape)
+    physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape(
+        logical_sharding_spec_for_x1, broadcast_shape, x1.shape)
     print(physical_sharding_spec_for_x1)

     assert physical_sharding_spec_for_x1.entire_shape == x1.shape

View File

@@ -105,7 +105,7 @@ def test_conv_module():
     assert weight_node._meta_data.shape == (6, 3, 2, 2)
     assert bias_node._meta_data.shape == (6,)
     assert conv_node._meta_data.shape == (4, 6, 63, 63)
-    assert view_node._meta_data.shape == (1, 6, 1, 1)
+    assert view_node._meta_data.shape == (6, 1, 1)
     assert add_node._meta_data.shape == (4, 6, 63, 63)