diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index b29ff3a65..0e3ea670c 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -223,7 +223,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): node.args = new_args elif isinstance(getitem_index, (tuple, list)): - assert isinstance(getitem_index[0], slice) + if not isinstance(getitem_index[0], slice): + continue new_slice_items = [] for slice_item in getitem_index: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index f510f7477..e8ae363e9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler'] @operator_registry.register(BCAST_FUNC_OP) -class BinaryElementwiseHandler(MetaInfoNodeHandler): +class BinaryElementwiseHandler(NodeHandler): """ An BinaryBcastOpHandler is a node handler which deals with operations which have two operands and broadcasting occurs such as torch.add. diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py index 2795c8544..0aeb2e0d4 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -7,7 +7,9 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ShardingStrategy, TrainCycleItem, ) +from colossalai.logging import get_dist_logger from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpecException from .strategy_generator import FollowingStrategyGenerator @@ -69,39 +71,61 @@ class TensorStrategyGenerator(GetItemStrategyGenerator): def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] + getitem_index = self.op_data['index'].data for index, strategy in enumerate(self.predecessor_node.strategies_vector): - dim_partition_dict_mapping = {} - communication_action_mapping = {} - dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict - dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) - gather_input = 0 in dim_partition_dict_for_input - if gather_input: - logical_process_axis = dim_partition_dict_for_output.pop(0) - - shift_dim_partition_dict_for_output = {} - for dim, mesh_dim_list in dim_partition_dict_for_output.items(): - shift_dim_partition_dict_for_output[dim - 1] = mesh_dim_list - dim_partition_dict_for_output = shift_dim_partition_dict_for_output - dim_partition_dict_mapping = { - "input": dim_partition_dict_for_input, - "output": dim_partition_dict_for_output, - } - sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - if gather_input: - input_communication_action = self.get_communication_action( - sharding_spec_mapping["input"], - communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - logical_process_axis=logical_process_axis, - comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping["input"] = input_communication_action - - name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}' - - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - + try: + logger = get_dist_logger() + dim_partition_dict_mapping = {} + communication_action_mapping = {} + dim_partition_dict_for_input = copy.deepcopy( + strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict) + + int_index = False + if isinstance(getitem_index, int): + int_index = True + getitem_dims = [ + 0, + ] + shift_length = 1 + elif isinstance(getitem_index, slice): + getitem_dims = [ + 0, + ] + else: + getitem_dims = [i for i in range(len(getitem_index))] + if isinstance(getitem_index[0], int): + int_index = True + shift_length = len(getitem_index) + + gather_dims = [] + for dim in getitem_dims: + if dim in dim_partition_dict_for_input: + gather_dims.append(dim) + + for dim in gather_dims: + dim_partition_dict_for_input.pop(dim) + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + + if int_index: + shift_dim_partition_dict_for_output = {} + for dim, mesh_dim_list in dim_partition_dict_for_output.items(): + shift_dim_partition_dict_for_output[dim - shift_length] = mesh_dim_list + dim_partition_dict_for_output = shift_dim_partition_dict_for_output + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + except ShardingSpecException as e: + logger.debug(e) + continue strategy_list.append(strategy) for strategy in strategy_list: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index c5012934c..3547767dc 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -1,59 +1,83 @@ +from functools import partial + +import pytest import torch +import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx.tracer.meta_patch.patched_module import linear +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy class GetItemFromTensorModel(nn.Module): - def __init__(self): + def __init__(self, getitem_index): super().__init__() + self.getitem_index = getitem_index def forward(self, input, other): - conv_node = nn.functional.conv2d(input, other) - x = conv_node[1] + linear_node = nn.functional.linear(input, other, bias=None) + x = linear_node[self.getitem_index] return x -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_getitem_from_tensor_handler(): - model = GetItemFromTensorModel() +def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + model = GetItemFromTensorModel(getitem_index=getitem_index) + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() - # graph(): - # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] - # %other : torch.Tensor [#users=1] = placeholder[target=other] - # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) - # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%conv2d, 1), kwargs = {}) - # return getitem + graph = tracer.trace(model, meta_args={ - "input": torch.rand(4, 4, 64, 64).to('meta'), - "other": torch.rand(4, 16, 3, 3).to('meta'), + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), }) - gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - conv_mod_node = list(graph.nodes)[2] + gm = ColoGraphModule(model, graph) + linear_mod_node = list(graph.nodes)[2] getitem_mod_node = list(graph.nodes)[3] getitem_strategies_vector = StrategiesVector(getitem_mod_node) - conv_strategies_vector = StrategiesVector(conv_mod_node) + linear_strategies_vector = StrategiesVector(linear_mod_node) # build handler - conv_handler = ConvFunctionHandler(node=conv_mod_node, - device_mesh=device_mesh, - strategies_vector=conv_strategies_vector) - conv_handler.register_strategy(compute_resharding_cost=False) - setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) + linear_handler = LinearFunctionHandler(node=linear_mod_node, + device_mesh=device_mesh, + strategies_vector=linear_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(linear_mod_node, 'strategies_vector', linear_strategies_vector) getitem_handler = GetItemHandler(node=getitem_mod_node, device_mesh=device_mesh, strategies_vector=getitem_strategies_vector) @@ -67,23 +91,22 @@ def test_getitem_from_tensor_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "conv2d" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) - - assert mapping['index'].name == "index" - assert isinstance(mapping['index'].data, int) - assert mapping['index'].type == OperationDataType.ARG + # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(getitem_strategies_vector) == len(linear_strategies_vector) - assert mapping['output'].name == "getitem" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 62, 62]) - assert mapping['output'].type == OperationDataType.OUTPUT - # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. - assert len(getitem_strategies_vector) == len(conv_strategies_vector) +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))]) +@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) +def test_getitem_from_tensor_handler(getitem_index): + world_size = 4 + run_func = partial(check_getitem_from_tensor_handler, + getitem_index=getitem_index, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) class GetItemFromTupleModel(nn.Module):