mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel]add essential CommActions for broadcast oprands (#1793)
parent 05ce3d369f
commit e34e850a4c
@@ -3,10 +3,17 @@ from typing import Dict, List, Union
 import torch
 from torch.fx.node import Node

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+)
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager

 from ..constants import BCAST_FUNC_OP
-from ..utils import recover_sharding_spec_for_broadcast_shape
+from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
 from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator

@@ -81,6 +88,15 @@ class BinaryElementwiseHandler(NodeHandler):
             physical_shape = op_data.data.shape
             logical_shape = op_data.logical_shape
             sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
-            sharding_spec = recover_sharding_spec_for_broadcast_shape(sharding_spec, logical_shape, physical_shape)
+            sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+                sharding_spec, logical_shape, physical_shape)
+
             strategy.sharding_specs[op_data] = sharding_spec
+            if len(removed_dims) > 0:
+                comm_action = comm_actions_for_oprands(node=self.node,
+                                                       removed_dims=removed_dims,
+                                                       op_data=op_data,
+                                                       sharding_spec=sharding_spec)
+                strategy.communication_actions[op_data] = comm_action
+
         return strategy
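
A minimal sketch (not part of the commit) of what this post-processing now records for a broadcast operand, using made-up shapes and mesh axes: for an elementwise op such as x + bias, where a bias of shape (8,) is broadcast to the logical shape (4, 8), a sharding placed on the broadcast-created dim cannot live on the physical bias, so its mesh axis ends up in removed_dims and triggers a CommAction.

logical_shape = (4, 8)                    # shape of the bias after broadcasting
physical_shape = (8,)                     # real shape of the bias tensor
logical_dim_partition = {0: [0], 1: [1]}  # logical dim 0 sharded on mesh axis 0, dim 1 on mesh axis 1

# logical dim 0 only exists because of broadcasting, so its sharding cannot be kept on the
# physical bias; recover_sharding_spec_for_broadcast_shape now reports mesh axis 0 as removed
physical_dim_partition = {0: [1]}
removed_dims = [0]

# since removed_dims is non-empty, the handler attaches the CommAction built by
# comm_actions_for_oprands: identity in the forward pass, all-reduce of the operand's
# gradient over mesh axis 0 in the backward pass
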
@@ -2,8 +2,10 @@ from typing import Dict, List, Union

 import torch

-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from ..utils import recover_sharding_spec_for_broadcast_shape
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
+
+from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
 from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator

@@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler):
             bias_physical_shape = bias_op_data.data.shape
             bias_logical_shape = bias_op_data.logical_shape
             bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
-            bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
-                                                                           bias_physical_shape)
+            bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+                bias_sharding_spec, bias_logical_shape, bias_physical_shape)
             strategy.sharding_specs[bias_op_data] = bias_sharding_spec
+
+            if len(removed_dims) > 0:
+                comm_action = comm_actions_for_oprands(node=self.node,
+                                                       removed_dims=removed_dims,
+                                                       op_data=bias_op_data,
+                                                       sharding_spec=bias_sharding_spec)
+                strategy.communication_actions[bias_op_data] = comm_action
+
         return strategy
@@ -213,7 +213,7 @@ class Broadcaster(BmmTransform):

             tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]

-            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
+            physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
                 logical_sharding_spec=sharding_spec,
                 logical_shape=sharding_spec.entire_shape,
                 physical_shape=tensor_shape_before_broadcast)
@@ -4,7 +4,7 @@ from typing import Dict, List

 import torch

-from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector)
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
 from ..utils import recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry

@@ -81,8 +81,8 @@ class WhereHandler(NodeHandler):
             logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
             logical_shape = logical_op_data_mapping[key].logical_shape
             physical_shape = physical_op_data_mapping[key].logical_shape
-            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec, logical_shape,
-                                                                               physical_shape)
+            physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+                logical_sharding_spec, logical_shape, physical_shape)
             strategy.sharding_specs.pop(logical_op_data_mapping[key])
             strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
         strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
@@ -1,4 +1,10 @@
-from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape
+from .broadcast import (
+    BroadcastType,
+    comm_actions_for_oprands,
+    get_broadcast_shape,
+    is_broadcastable,
+    recover_sharding_spec_for_broadcast_shape,
+)
 from .factory import generate_resharding_costs, generate_sharding_spec
 from .misc import check_sharding_spec_validity, ignore_sharding_exception
 from .sharding import (

@@ -13,5 +19,5 @@ __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
     'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
-    'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
+    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands'
 ]
@@ -2,10 +2,21 @@ from enum import Enum, auto
 from typing import List

 import torch
+from torch.fx.node import Node

+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+)
+from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec

-__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
+__all__ = [
+    'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
+    'comm_actions_for_oprands'
+]


 class BroadcastType(Enum):
@@ -86,8 +97,11 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec,
     """
     # if the two shapes are the same, no broadcast occurs
     # we directly return the current sharding spec
+
+    # record the sharding dimensions that are removed while converting the logical shape to the physical one
+    removed_dims = []
     if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
+        return logical_sharding_spec, removed_dims

     # get the number of dimensions
     logical_num_dims = len(logical_shape)
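
A simplified sketch of the new bookkeeping, under stated assumptions: BroadcastKind and split_partition below are invented stand-ins for BroadcastType and the real ShardingSpec-based logic, and only illustrate how shardings on broadcast-created dims are collected into removed_dims while the remaining dims are remapped to their physical positions.

from enum import Enum, auto
from typing import Dict, List, Tuple


class BroadcastKind(Enum):
    # simplified stand-in for BroadcastType, used only in this sketch
    EQUAL = auto()
    PADDING = auto()     # dim exists only in the logical (broadcast) shape
    MULTIPLE = auto()    # physical size 1 expanded to a larger logical size


def split_partition(logical_partition: Dict[int, List[int]],
                    broadcast_info: Dict[int, BroadcastKind],
                    logical_ndim: int,
                    physical_ndim: int) -> Tuple[Dict[int, List[int]], List[int]]:
    """Keep shardings that map to a real physical dim; collect the mesh axes of
    broadcast-created dims into removed_dims."""
    physical_partition: Dict[int, List[int]] = {}
    removed_dims: List[int] = []
    for dim, mesh_axes in logical_partition.items():
        if broadcast_info[dim] in (BroadcastKind.PADDING, BroadcastKind.MULTIPLE):
            removed_dims.extend(mesh_axes)
        else:
            physical_partition[physical_ndim - (logical_ndim - dim)] = mesh_axes
    return physical_partition, removed_dims


# logical (4, 8) broadcast from a physical (1, 8): logical dim 0 was expanded from size 1
partition, removed = split_partition({0: [0], 1: [1]},
                                     {0: BroadcastKind.MULTIPLE, 1: BroadcastKind.EQUAL},
                                     logical_ndim=2, physical_ndim=2)
assert partition == {1: [1]} and removed == [0]
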
@@ -104,7 +118,7 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec,
         logical_broadcast_type = logical_dim_broadcast_info[shape_dim]

         if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
-            pass
+            removed_dims.extend(mesh_dim)
         else:
             # get the corresponding physical dim
             physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
@@ -114,4 +128,33 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec,
                                              entire_shape=physical_shape,
                                              dim_partition_dict=physical_dim_partition)

-    return physical_sharding_spec
+    return physical_sharding_spec, removed_dims
+
+
+def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
+                             sharding_spec: ShardingSpec) -> CommAction:
+    """
+    This method is used to generate communication actions for operands which lose sharding information
+    when their logical (broadcast) shape is converted back to the physical shape.
+    """
+    if len(removed_dims) == 1:
+        # if the list has a single element, extract it to avoid using a flattened device mesh
+        removed_dims = removed_dims[0]
+    comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                         sharding_spec=sharding_spec,
+                         logical_process_axis=removed_dims)
+    if op_data.type == OperationDataType.PARAM:
+        comm_type = CommType.HOOK
+    else:
+        comm_type = CommType.BEFORE
+    arg_index = -1
+    for index, arg in enumerate(node.args):
+        if op_data.name == str(arg):
+            arg_index = index
+    assert arg_index >= 0, 'op_data should be an argument of node.'
+    comm_action = CommAction(
+        comm_spec=comm_spec,
+        comm_type=comm_type,
+        arg_index=arg_index,
+    )
+    return comm_action
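
The chosen pattern, CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, leaves the forward pass untouched and all-reduces the operand's gradient over the removed mesh axes in backward. Below is a toy, single-process analogue using a custom autograd function; IdentityFwdAllReduceBwd is invented for this illustration and does not call the real CommSpec machinery.

import torch


class IdentityFwdAllReduceBwd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, process_group=None):
        # forward is the identity: the operand keeps its physical shape and values
        ctx.process_group = process_group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # in a real distributed run, the gradient would be all-reduced here over the removed
        # mesh axes, e.g. torch.distributed.all_reduce(grad_output, group=ctx.process_group);
        # no process group is initialised in this sketch, so the gradient passes through unchanged
        return grad_output, None


x = torch.randn(4, 8, requires_grad=True)
y = IdentityFwdAllReduceBwd.apply(x)   # no-op in the forward pass
y.sum().backward()                     # the all-reduce would run here during backward
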
@@ -39,8 +39,8 @@ class BiasAdditionConv(BiasAdditionModule):
         This method is used to reshape the bias node in order to make bias and
         output of non-bias convolution broadcastable.
         """
-        bias_shape = [1] * dimensions
-        bias_shape[1] = -1
+        bias_shape = [1] * (dimensions - 1)
+        bias_shape[0] = -1
         bias_reshape_node_kind = 'call_method'
         bias_reshape_node_target = 'view'
         bias_reshape_node_args = (self.bias_proxy, bias_shape)
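
A quick check of the reshape change (not part of the commit), assuming the Conv2d case exercised by the test further below: the convolution output has shape (4, 6, 63, 63) and the bias has shape (6,), so dimensions == 4. The bias view changes from (1, 6, 1, 1) to (6, 1, 1), and both forms broadcast identically against the convolution output.

import torch

conv_out = torch.randn(4, 6, 63, 63)   # output of a Conv2d with 6 output channels
bias = torch.randn(6)

old_view = bias.view(1, 6, 1, 1)       # old: [1] * dimensions, then bias_shape[1] = -1
new_view = bias.view(6, 1, 1)          # new: [1] * (dimensions - 1), then bias_shape[0] = -1

# both views broadcast against (4, 6, 63, 63) and give the same result;
# the new form simply drops the leading dim that duplicated the batch axis
assert torch.equal(conv_out + old_view, conv_out + new_view)
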
@@ -1,7 +1,10 @@
 import torch

-from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
-                                                          recover_sharding_spec_for_broadcast_shape)
+from colossalai.auto_parallel.tensor_shard.utils import (
+    get_broadcast_shape,
+    is_broadcastable,
+    recover_sharding_spec_for_broadcast_shape,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec


@@ -51,8 +54,8 @@ def test_recover_sharding_spec_for_broadcast_shape():
                                                          1: [1]
                                                      },
                                                      entire_shape=broadcast_shape)
-    physical_sharding_spec_for_x1 = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec_for_x1,
-                                                                              broadcast_shape, x1.shape)
+    physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape(
+        logical_sharding_spec_for_x1, broadcast_shape, x1.shape)
     print(physical_sharding_spec_for_x1)

     assert physical_sharding_spec_for_x1.entire_shape == x1.shape
@@ -105,7 +105,7 @@ def test_conv_module():
     assert weight_node._meta_data.shape == (6, 3, 2, 2)
     assert bias_node._meta_data.shape == (6,)
     assert conv_node._meta_data.shape == (4, 6, 63, 63)
-    assert view_node._meta_data.shape == (1, 6, 1, 1)
+    assert view_node._meta_data.shape == (6, 1, 1)
     assert add_node._meta_data.shape == (4, 6, 63, 63)