diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 8e38d34ca..d8dbaa0ac 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -1,7 +1,8 @@ from .batch_norm_handler import BatchNormModuleHandler +from .bmm_handler import BMMFunctionHandler from .conv_handler import ConvFunctionHandler, ConvModuleHandler -from .dot_handler import LinearFunctionHandler, LinearModuleHandler from .layer_norm_handler import LayerNormModuleHandler +from .linear_handler import LinearFunctionHandler, LinearModuleHandler from .normal_pooling_handler import NormPoolingHandler from .output_handler import OuputHandler from .placeholder_handler import PlacehodlerHandler @@ -11,7 +12,7 @@ from .unary_elementwise_handler import UnaryElementwiseHandler from .where_handler import WhereHandler __all__ = [ - 'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler', - 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', - 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry' + 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'LayerNormModuleHandler', + 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', + 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py new file mode 100644 index 000000000..a1ca06a74 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py @@ -0,0 +1,33 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator + + +@operator_registry.register(torch.bmm) +@operator_registry.register(torch.Tensor.bmm) +class BMMFunctionHandler(NodeHandler): + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.node.args[1]._meta_data) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + generators = [] + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py similarity index 86% rename from colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index 3de03f440..4a8af4ca7 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py 
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -1,19 +1,18 @@ -from copy import deepcopy from typing import Dict, List, Union import torch import torch.nn.functional as F -from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim) +from colossalai.auto_parallel.tensor_shard.utils import tranpose_partition_dim, update_partition_dim from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingNotDivisibleError -from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy) +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from .node_handler import ModuleHandler, NodeHandler from .registry import operator_registry -from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyGenerator, StrategyGenerator) +from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator -__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler'] +__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy, @@ -31,7 +30,7 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr op_data = strategy.get_op_data_by_name(weight_name) assert op_data.logical_shape != op_data.data.shape, \ "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same" - switch_partition_dim(sharding_spec, 0, -1) + tranpose_partition_dim(sharding_spec, 0, -1) return strategy @@ -104,8 +103,6 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha dim_mapping={}, physical_shape=output_op_data.data.shape, inplace=True) - print(input_op_data.data.shape) - print(output_op_data.data.shape) sharding_strategies.append(strategy_copy) return sharding_strategies @@ -144,7 +141,7 @@ class LinearModuleHandler(ModuleHandler): mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if self.named_parameters['bias'] is not None: + if 'bias' in self.named_parameters: physical_bias_operand = OperationData(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']) @@ -229,30 +226,3 @@ class LinearFunctionHandler(NodeHandler): input_name=str(self.node.args[0]), output_name=str(self.node)) return strategies - - -@operator_registry.register(torch.bmm) -@operator_registry.register(torch.Tensor.bmm) -class BMMFunctionHandler(NodeHandler): - - def get_operation_data_mapping(self) -> Dict[str, OperationData]: - # use transposed shape for strategies - # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=OperationDataType.ARG, - data=self.node.args[1]._meta_data) - physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) - - mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - return mapping - - def get_strategy_generator(self) -> List[StrategyGenerator]: - generators = [] - op_data_mapping = self.get_operation_data_mapping() - generators = [] - generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
- return generators diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index 175ef6631..d178ebe7a 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -2,9 +2,8 @@ import operator from functools import reduce from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) -from colossalai.auto_parallel.tensor_shard.utils import \ - ignore_sharding_exception +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -227,6 +226,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # set communication action + communication_action_mapping = {} input_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, @@ -235,12 +235,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0) - bias_comm_spec = self.get_communication_spec( - sharding_spec_mapping["bias"], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim_0) - communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec} + communication_action_mapping['input'] = input_comm_spec + communication_action_mapping['other'] = other_comm_spec + + if self.has_bias: + bias_comm_spec = self.get_communication_spec( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0) + communication_action_mapping['bias'] = bias_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -268,6 +272,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action mapping + communication_action_mapping = {} input_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, @@ -276,12 +281,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1) - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping["bias"], - communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, - logical_process_axis=mesh_dim_1) - communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec} + communication_action_mapping['input'] = input_comm_spec + communication_action_mapping['output'] = output_comm_spec + + if self.has_bias: + bias_comm_spec 
= self.get_communication_spec( + sharding_spec=sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_1) + communication_action_mapping['bias'] = bias_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -310,6 +319,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication actions + communication_action_mapping = {} output_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, @@ -318,7 +328,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1) - communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec} + communication_action_mapping["input"] = input_comm_spec + communication_action_mapping['output'] = output_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -342,11 +353,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action + communication_action_mapping = {} output_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim) - communication_action_mapping = {'output': output_comm_spec} + + communication_action_mapping['output'] = output_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -372,11 +385,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication actions + communication_action_mapping = {} input_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim) - communication_action_mapping = {'input': input_comm_spec} + + communication_action_mapping['input'] = input_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -398,19 +413,22 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action + communication_action_mapping = {} other_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['other'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1]) - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['bias'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) + communication_action_mapping['other'] = other_comm_spec - communcation_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec} + if self.has_bias: + 
bias_comm_spec = self.get_communication_spec( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1]) + communication_action_mapping['bias'] = bias_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communcation_action_mapping) + communication_action_mapping=communication_action_mapping) @ignore_sharding_exception def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): @@ -430,11 +448,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action + communication_action_mapping = {} output_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['output'], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping = {'output': output_comm_spec} + communication_action_mapping['output'] = output_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -460,11 +479,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action + communication_action_mapping = {} input_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['input'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping = {'input': input_comm_spec} + communication_action_mapping['input'] = input_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, @@ -492,7 +512,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): """ Generate sharding strategies for the batched matrix multiplication. 
- A batched matrix multiplication can be viewed as + A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] """ @@ -642,7 +662,6 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): "bias": {}, "output": { 0: [mesh_dim_0], - -2: [mesh_dim_1] } } sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py index c570ac871..9032fc58f 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -5,13 +5,13 @@ from .sharding import ( enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size, - switch_partition_dim, + tranpose_partition_dim, update_partition_dim, ) __all__ = [ 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' - 'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', + 'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' ] diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 9a445869f..c0ef6df88 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -36,9 +36,10 @@ def ignore_sharding_exception(func): def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor): """ This function checks whether the ShardingSpec is valid for the physical tensor. - This check includes 2 items: + This check includes 3 items: 1. the sharding spec covers all dimensions of the physical tensor 2. the sharding spec for each dimension is divisible by the number of devices. + 3. the sharding spec's entire shape must match the tensor shape # """ # make sure all dims are covered in sharding spec @@ -65,3 +66,6 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens assert dim_size >= num_devices and dim_size % num_devices == 0, \ f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.' + + # make sure the entire shape matches the physical tensor shape + assert sharding_spec.entire_shape == tensor.shape diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py index ae5d250a4..622a33367 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/sharding.py +++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py @@ -8,12 +8,12 @@ import torch from colossalai.tensor.sharding_spec import ShardingSpec __all__ = [ - 'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', + 'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' ] -def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec: +def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec: """ Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place. 
@@ -22,19 +22,26 @@ def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> S dim1 (int): the tensor dimension to switch dim2 (int): the tensor dimension to switch """ - assert len(sharding_spec.entire_shape) == 2 + assert len(sharding_spec.entire_shape) >= 2, \ + 'The entire_shape of the sharding spec must have at least 2 dimensions' dim_partition_dict = sharding_spec.dim_partition_dict + + # transpose the dim partition dim1_partition = dim_partition_dict.pop(dim1, None) dim2_partition = dim_partition_dict.pop(dim2, None) if dim1_partition: dim_partition_dict[dim2] = dim1_partition - if dim2_partition: dim_partition_dict[dim1] = dim2_partition + # get the transposed shape + new_shape = list(sharding_spec.entire_shape[:]) + new_shape[dim2], new_shape[dim1] = new_shape[dim1], new_shape[dim2] + new_shape = torch.Size(new_shape) + # re-init the sharding spec - sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict) + sharding_spec.__init__(sharding_spec.device_mesh, new_shape, dim_partition_dict) return sharding_spec diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index 76cbe6bd5..5d7272bec 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -2,12 +2,10 @@ import pytest import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import \ - BMMFunctionHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.testing.pytest_wrapper import run_on_environment_flag class BMMTensorMethodModule(nn.Module): @@ -91,6 +89,16 @@ def test_2d_device_mesh(module): assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) def test_1d_device_mesh(module): @@ -145,6 +153,16 @@ def test_1d_device_mesh(module): # one batch dim assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == 
output_sharding_spec.sharding_sequence[:-1] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + if __name__ == '__main__': test_1d_device_mesh(BMMTensorMethodModule) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index b2d6754a5..69fd411e0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -5,8 +5,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag +@run_on_environment_flag(name='AUTO_PARALLEL') def test_conv_module_handler(): model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta')) tracer = ColoTracer() @@ -108,6 +110,7 @@ class ConvModel(nn.Module): return x +@run_on_environment_flag(name='AUTO_PARALLEL') def test_conv_function_handler(): model = ConvModel() tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index 37a612de1..d185eb6db 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -1,14 +1,13 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \ - ConvFunctionHandler -from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import \ - GetItemHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx.tracer.meta_patch.patched_module import linear +from colossalai.testing.pytest_wrapper import run_on_environment_flag class GetItemModel(nn.Module): @@ -22,6 +21,7 @@ class GetItemModel(nn.Module): return x +@run_on_environment_flag(name='AUTO_PARALLEL') def test_getitem_function_handler(): model = GetItemModel() tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index d2f26e704..290d73f5a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn +from typing_extensions import Self -from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import 
LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -10,10 +11,12 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ) from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.utils import parameterize -def test_linear_module_handler(): - model = nn.Sequential(nn.Linear(16, 32).to('meta')) +@parameterize('bias', [True, False]) +def test_linear_module_handler(bias): + model = nn.Sequential(nn.Linear(16, 32, bias=bias).to('meta')) tracer = ColoTracer() graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')}) @@ -50,11 +53,12 @@ def test_linear_module_handler(): assert mapping['other'].type == OperationDataType.PARAM assert mapping['other'].logical_shape == torch.Size([16, 32]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size([32]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([32]) + if bias: + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size([32]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([32]) assert mapping['output'].name == "_0" assert mapping['output'].data.is_meta @@ -91,18 +95,23 @@ def test_linear_module_handler(): strategy: ShardingStrategy input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] - assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -def test_linear_function_handler(): - model = nn.Linear(16, 32).to('meta') +@parameterize('bias', [True, False]) +def test_linear_function_handler(bias): + model = nn.Linear(16, 32, bias=bias).to('meta') tracer = ColoTracer() graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')}) gm = ColoGraphModule(model, graph) @@ -111,7 +120,11 @@ def test_linear_function_handler(): print(graph) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - linear_func_node = list(graph.nodes)[3] + + if bias: + linear_func_node = list(graph.nodes)[3] + else: + linear_func_node = list(graph.nodes)[2] strategies_vector = StrategiesVector(linear_func_node) # build handler @@ -120,8 +133,6 @@ def test_linear_function_handler(): # # check operation data mapping mapping = handler.get_operation_data_mapping() - print(mapping['input'].logical_shape) - assert mapping['input'].name == "input_1" assert mapping['input'].data.is_meta 
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16]) @@ -134,11 +145,12 @@ assert mapping['other'].type == OperationDataType.PARAM assert mapping['other'].logical_shape == torch.Size([16, 32]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size([32]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + if bias: + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size([32]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([32]) assert mapping['output'].name == "linear" assert mapping['output'].data.is_meta @@ -172,17 +184,20 @@ for strategy in strategies_vector: strategy: ShardingStrategy - print(strategy) input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + if bias: + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] - assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + if bias: + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py index b35fc64b6..de277002b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py @@ -1,13 +1,12 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \ - ConvFunctionHandler -from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import \ - ReshapeHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag class ReshapeModel(nn.Module): @@ -21,6 +20,7 @@ class ReshapeModel(nn.Module): return reshape_node +@run_on_environment_flag(name='AUTO_PARALLEL') def test_reshape_handler(): model = ReshapeModel() tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py index b27c0e412..a861cb7f5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -1,13 +1,13 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \ - ConvFunctionHandler -from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \ - UnaryElementwiseHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx.tracer.meta_patch.patched_module import linear +from colossalai.testing.pytest_wrapper import run_on_environment_flag class ReLuModel(nn.Module): @@ -22,6 +22,7 @@ class ReLuModel(nn.Module): return relu_node +@run_on_environment_flag(name='AUTO_PARALLEL') def test_elementwise_handler(): model = ReLuModel() tracer = ColoTracer()
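The BatchedMatMulStrategyGenerator docstring above describes a batched matmul as [b, i, k] x [b, k, j] -> [b, i, j]. A minimal torch.bmm shape check of that view (the sizes below are illustrative and not taken from the diff):

import torch

# batched matmul view used by BatchedMatMulStrategyGenerator:
# [b, i, k] x [b, k, j] -> [b, i, j]
b, i, k, j = 4, 8, 16, 32    # illustrative sizes
lhs = torch.rand(b, i, k)
rhs = torch.rand(b, k, j)
out = torch.bmm(lhs, rhs)
assert out.shape == torch.Size([b, i, j])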
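The bias=True/False parameterization added to test_linear_handler.py relies on nn.Linear registering no 'bias' parameter when bias=False, which is what the membership check on named_parameters in LinearModuleHandler keys off. A small sketch of that behaviour, using plain torch only:

import torch.nn as nn

# nn.Linear(..., bias=False) registers no 'bias' parameter, so a membership
# check on named_parameters lets the handler skip the bias operand entirely.
with_bias = dict(nn.Linear(16, 32).named_parameters())
without_bias = dict(nn.Linear(16, 32, bias=False).named_parameters())
assert 'bias' in with_bias
assert 'bias' not in without_bias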
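_update_sharding_spec_for_transposed_weight_for_linear asserts that the weight's logical and physical shapes differ before transposing partition dims 0 and -1: nn.Linear stores its weight as [out_features, in_features], while the strategies are generated against the transposed logical shape and the weight's sharding spec is transposed back afterwards. A quick shape check with plain torch (sizes match the tests above):

import torch
import torch.nn as nn

# nn.Linear(16, 32) stores its weight as [out, in] = [32, 16]; the handler
# reasons about the transposed logical shape [16, 32], which is why the
# weight's sharding spec is transposed back in post-processing.
weight = nn.Linear(16, 32).weight
assert weight.shape == torch.Size([32, 16])
assert weight.t().shape == torch.Size([16, 32])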