mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel]add essential CommActions for broadcast oprands (#1793)
parent 05ce3d369f
commit e34e850a4c
@@ -3,10 +3,17 @@ from typing import Dict, List, Union
import torch
from torch.fx.node import Node

from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
    OperationData,
    OperationDataType,
    ShardingStrategy,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager

from ..constants import BCAST_FUNC_OP
from ..utils import recover_sharding_spec_for_broadcast_shape
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
@@ -81,6 +88,15 @@ class BinaryElementwiseHandler(NodeHandler):
physical_shape = op_data.data.shape
logical_shape = op_data.logical_shape
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
sharding_spec = recover_sharding_spec_for_broadcast_shape(sharding_spec, logical_shape, physical_shape)
sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
    sharding_spec, logical_shape, physical_shape)

strategy.sharding_specs[op_data] = sharding_spec
if len(removed_dims) > 0:
    comm_action = comm_actions_for_oprands(node=self.node,
                                           removed_dims=removed_dims,
                                           op_data=op_data,
                                           sharding_spec=sharding_spec)
    strategy.communication_actions[op_data] = comm_action

return strategy
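Why a communication action is needed here: when an operand is broadcast, its gradient is implicitly summed over the broadcast dimensions, so if such a dimension was sharded across a mesh axis the partial gradients have to be all-reduced (this is what the IDENTITY_FWD_ALLREDUCE_BWD comm spec further down in this commit provides). A minimal single-process PyTorch sketch of the underlying effect, not part of the commit itself:

import torch

x = torch.randn(3, 4)
b = torch.randn(4, requires_grad=True)    # broadcast operand, physical shape (4,)

(x + b).sum().backward()                  # b is broadcast to the logical shape (3, 4)
print(b.grad)                             # equals torch.ones(3, 4).sum(dim=0): the gradient is
                                          # reduced over the broadcast dim, so a sharded broadcast
                                          # dim needs an all-reduce in the backward pass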
@@ -2,8 +2,10 @@ from typing import Dict, List, Union

import torch

from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager

from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
@@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler):
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
                                                               bias_physical_shape)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
    bias_sharding_spec, bias_logical_shape, bias_physical_shape)
strategy.sharding_specs[bias_op_data] = bias_sharding_spec

if len(removed_dims) > 0:
    comm_action = comm_actions_for_oprands(node=self.node,
                                           removed_dims=removed_dims,
                                           op_data=bias_op_data,
                                           sharding_spec=bias_sharding_spec)
    strategy.communication_actions[bias_op_data] = comm_action

return strategy
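For context, torch.addbmm broadcasts its bias argument against the (n, p) result of the batched matmul, which is why the handler shards the bias on its logical shape and only recovers the physical spec in this post-processing step. A small sketch of the shapes involved (illustration only, shapes chosen arbitrarily):

import torch

batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
bias = torch.randn(5)                      # physical shape (5,), logical shape (3, 5) after broadcast

out = torch.addbmm(bias, batch1, batch2)
print(out.shape)                           # torch.Size([3, 5]); any mesh axis that sharded the
                                           # broadcast row dim of the bias ends up in removed_dims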
@@ -213,7 +213,7 @@ class Broadcaster(BmmTransform):

tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]

physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
    logical_sharding_spec=sharding_spec,
    logical_shape=sharding_spec.entire_shape,
    physical_shape=tensor_shape_before_broadcast)
@@ -4,7 +4,7 @@ from typing import Dict, List

import torch

from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector)
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
@@ -81,8 +81,8 @@ class WhereHandler(NodeHandler):
logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
logical_shape = logical_op_data_mapping[key].logical_shape
physical_shape = physical_op_data_mapping[key].logical_shape
physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec, logical_shape,
                                                                   physical_shape)
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
    logical_sharding_spec, logical_shape, physical_shape)
strategy.sharding_specs.pop(logical_op_data_mapping[key])
strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
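torch.where likewise broadcasts condition, x and y to a common logical shape; strategies are generated on that shared shape, and each operand's physical sharding spec is recovered here. A quick illustration of the logical vs. physical shapes (not part of the commit):

import torch

condition = torch.rand(16, 32) > 0.5
x = torch.randn(16, 32)
y = torch.randn(32)                        # physically 1-D, logically broadcast to (16, 32)

out = torch.where(condition, x, y)
print(out.shape)                           # torch.Size([16, 32]), the shared logical shape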
@@ -1,4 +1,10 @@
from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape
from .broadcast import (
    BroadcastType,
    comm_actions_for_oprands,
    get_broadcast_shape,
    is_broadcastable,
    recover_sharding_spec_for_broadcast_shape,
)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import check_sharding_spec_validity, ignore_sharding_exception
from .sharding import (
@@ -13,5 +19,5 @@ __all__ = [
    'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
    'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
    'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
    'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands'
]
@@ -2,10 +2,21 @@ from enum import Enum, auto
from typing import List

import torch
from torch.fx.node import Node

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
    OperationData,
    OperationDataType,
)
from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
__all__ = [
    'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
    'comm_actions_for_oprands'
]


class BroadcastType(Enum):
@@ -86,8 +97,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec

# recording the sharding dimensions removed during logical shape converting to physical one
removed_dims = []
if list(logical_shape) == list(physical_shape):
    return logical_sharding_spec
    return logical_sharding_spec, removed_dims

# get the number of dimensions
logical_num_dims = len(logical_shape)
@@ -104,7 +118,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
logical_broadcast_type = logical_dim_broadcast_info[shape_dim]

if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
    pass
    removed_dims.extend(mesh_dim)
else:
    # get the corresponding physical dim
    physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
@@ -114,4 +128,33 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
        entire_shape=physical_shape,
        dim_partition_dict=physical_dim_partition)

return physical_sharding_spec
return physical_sharding_spec, removed_dims


def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
                             sharding_spec: ShardingSpec) -> CommAction:
    """
    This method is used to generate communication actions for oprands which lose information
    during convert logical shape to physical shape.
    """
    if len(removed_dims) == 1:
        # if list length is 1, extract element from list to avoid using flatten device mesh
        removed_dims = removed_dims[0]
    comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                         sharding_spec=sharding_spec,
                         logical_process_axis=removed_dims)
    if op_data.type == OperationDataType.PARAM:
        comm_type = CommType.HOOK
    else:
        comm_type = CommType.BEFORE
    arg_index = -1
    for index, arg in enumerate(node.args):
        if op_data.name == str(arg):
            arg_index = index
    assert arg_index >= 0, f'op_data should be an argument of node.'
    comm_action = CommAction(
        comm_spec=comm_spec,
        comm_type=comm_type,
        arg_index=arg_index,
    )
    return comm_action
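A simplified, self-contained sketch of the bookkeeping added above (hypothetical helper removed_mesh_axes; plain dicts stand in for ShardingSpec and the device mesh): a dimension that exists only in the broadcast (logical) shape cannot stay sharded on the physical operand, so its mesh axes are collected and later compensated by a CommAction.

def removed_mesh_axes(logical_dim_partition, broadcast_info):
    """Collect the device-mesh axes whose sharding disappears when the logical
    (broadcast) shape is mapped back to an operand's physical shape."""
    removed = []
    for dim, mesh_axes in logical_dim_partition.items():
        # dims that exist only because of broadcasting cannot stay sharded
        if broadcast_info.get(dim) in ('padding', 'multiple'):
            removed.extend(mesh_axes)
    return removed


# logical shape (4, 8) broadcast from physical shape (8,): dim 0 is a padded
# broadcast dim, so its sharding on mesh axis 0 is dropped and reported
print(removed_mesh_axes({0: [0], 1: [1]}, {0: 'padding', 1: 'equal'}))    # -> [0]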
@@ -39,8 +39,8 @@ class BiasAdditionConv(BiasAdditionModule):
This method is used to reshape the bias node in order to make bias and
output of non-bias convolution broadcastable.
"""
bias_shape = [1] * dimensions
bias_shape[1] = -1
bias_shape = [1] * (dimensions - 1)
bias_shape[0] = -1
bias_reshape_node_kind = 'call_method'
bias_reshape_node_target = 'view'
bias_reshape_node_args = (self.bias_proxy, bias_shape)
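The reshape change drops the leading singleton dimension: a (6,)-element conv bias is now viewed as (6, 1, 1) instead of (1, 6, 1, 1), which broadcasts against the NCHW conv output identically (see the updated test assertion further below). A quick check with plain PyTorch (illustration only):

import torch

out = torch.randn(4, 6, 63, 63)            # conv output with 6 channels
bias = torch.randn(6)

old_view = bias.view([1, -1, 1, 1])        # previous reshape: (1, 6, 1, 1)
new_view = bias.view([-1, 1, 1])           # reshape after this commit: (6, 1, 1)

torch.testing.assert_close(out + old_view, out + new_view)   # identical broadcasting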
@@ -1,7 +1,10 @@
import torch

from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
                                                          recover_sharding_spec_for_broadcast_shape)
from colossalai.auto_parallel.tensor_shard.utils import (
    get_broadcast_shape,
    is_broadcastable,
    recover_sharding_spec_for_broadcast_shape,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
@@ -51,8 +54,8 @@ def test_recover_sharding_spec_for_broadcast_shape():
        1: [1]
    },
    entire_shape=broadcast_shape)
physical_sharding_spec_for_x1 = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec_for_x1,
                                                                          broadcast_shape, x1.shape)
physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape(
    logical_sharding_spec_for_x1, broadcast_shape, x1.shape)
print(physical_sharding_spec_for_x1)

assert physical_sharding_spec_for_x1.entire_shape == x1.shape
@@ -105,7 +105,7 @@ def test_conv_module():
assert weight_node._meta_data.shape == (6, 3, 2, 2)
assert bias_node._meta_data.shape == (6,)
assert conv_node._meta_data.shape == (4, 6, 63, 63)
assert view_node._meta_data.shape == (1, 6, 1, 1)
assert view_node._meta_data.shape == (6, 1, 1)
assert add_node._meta_data.shape == (4, 6, 63, 63)