diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index bb8db54a4..1c39dc247 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -7,6 +7,7 @@ from numbers import Number
 from typing import Any, Callable, List
 
 import torch
+from packaging import version
 
 aten = torch.ops.aten
 
@@ -188,131 +189,136 @@ def zero_flop_jit(*args):
     return 0
 
 
-flop_mapping = {
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    flop_mapping = {
     # gemm
-    aten.mm.default: matmul_flop_jit,
-    aten.matmul.default: matmul_flop_jit,
-    aten.addmm.default: addmm_flop_jit,
-    aten.bmm.default: bmm_flop_jit,
+        aten.mm.default: matmul_flop_jit,
+        aten.matmul.default: matmul_flop_jit,
+        aten.addmm.default: addmm_flop_jit,
+        aten.bmm.default: bmm_flop_jit,
     # convolution
-    aten.convolution.default: conv_flop_jit,
-    aten._convolution.default: conv_flop_jit,
-    aten.convolution_backward.default: conv_backward_flop_jit,
+        aten.convolution.default: conv_flop_jit,
+        aten._convolution.default: conv_flop_jit,
+        aten.convolution_backward.default: conv_backward_flop_jit,
     # normalization
-    aten.native_batch_norm.default: batchnorm_flop_jit,
-    aten.native_batch_norm_backward.default: batchnorm_flop_jit,
-    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
-    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
-    aten.native_layer_norm.default: norm_flop_counter(2, 0),
-    aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
+        aten.native_batch_norm.default: batchnorm_flop_jit,
+        aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+        aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+        aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
+        aten.native_layer_norm.default: norm_flop_counter(2, 0),
+        aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
     # pooling
-    aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
-    aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
-    aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
-    aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
-    aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
-    aten.max_pool1d.default: elementwise_flop_counter(1, 0),
-    aten.max_pool2d.default: elementwise_flop_counter(1, 0),
-    aten.max_pool3d.default: elementwise_flop_counter(1, 0),
-    aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
-    aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
-    aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
-    aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
-    aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
-    aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
-    aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
-    aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
-    aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
-    aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
-    aten.embedding.default: elementwise_flop_counter(1, 0),
-}
+        aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
+        aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
+        aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
+        aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
+        aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+        aten.max_pool1d.default: elementwise_flop_counter(1, 0),
+        aten.max_pool2d.default: elementwise_flop_counter(1, 0),
+        aten.max_pool3d.default: elementwise_flop_counter(1, 0),
+        aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
+        aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
+        aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
+        aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
+        aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
+        aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
+        aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
+        aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
+        aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+        aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
+        aten.embedding.default: elementwise_flop_counter(1, 0),
+    }
 
-elementwise_flop_aten = [
+    elementwise_flop_aten = [
     # basic op
-    aten.add.Tensor,
-    aten.add_.Tensor,
-    aten.div.Tensor,
-    aten.div_.Tensor,
-    aten.div.Scalar,
-    aten.div_.Scalar,
-    aten.mul.Tensor,
-    aten.mul.Scalar,
-    aten.mul_.Tensor,
-    aten.neg.default,
-    aten.pow.Tensor_Scalar,
-    aten.rsub.Scalar,
-    aten.sum.default,
-    aten.sum.dim_IntList,
-    aten.mean.dim,
+        aten.add.Tensor,
+        aten.add_.Tensor,
+        aten.div.Tensor,
+        aten.div_.Tensor,
+        aten.div.Scalar,
+        aten.div_.Scalar,
+        aten.mul.Tensor,
+        aten.mul.Scalar,
+        aten.mul_.Tensor,
+        aten.neg.default,
+        aten.pow.Tensor_Scalar,
+        aten.rsub.Scalar,
+        aten.sum.default,
+        aten.sum.dim_IntList,
+        aten.mean.dim,
     # activation op
-    aten.hardswish.default,
-    aten.hardswish_.default,
-    aten.hardswish_backward.default,
-    aten.hardtanh.default,
-    aten.hardtanh_.default,
-    aten.hardtanh_backward.default,
-    aten.hardsigmoid_backward.default,
-    aten.hardsigmoid.default,
-    aten.gelu.default,
-    aten.gelu_backward.default,
-    aten.silu.default,
-    aten.silu_.default,
-    aten.silu_backward.default,
-    aten.sigmoid.default,
-    aten.sigmoid_backward.default,
-    aten._softmax.default,
-    aten._softmax_backward_data.default,
-    aten.relu_.default,
-    aten.relu.default,
-    aten.tanh.default,
-    aten.tanh_backward.default,
-    aten.threshold_backward.default,
+        aten.hardswish.default,
+        aten.hardswish_.default,
+        aten.hardswish_backward.default,
+        aten.hardtanh.default,
+        aten.hardtanh_.default,
+        aten.hardtanh_backward.default,
+        aten.hardsigmoid_backward.default,
+        aten.hardsigmoid.default,
+        aten.gelu.default,
+        aten.gelu_backward.default,
+        aten.silu.default,
+        aten.silu_.default,
+        aten.silu_backward.default,
+        aten.sigmoid.default,
+        aten.sigmoid_backward.default,
+        aten._softmax.default,
+        aten._softmax_backward_data.default,
+        aten.relu_.default,
+        aten.relu.default,
+        aten.tanh.default,
+        aten.tanh_backward.default,
+        aten.threshold_backward.default,
     # dropout
-    aten.native_dropout.default,
-    aten.native_dropout_backward.default,
-]
+        aten.native_dropout.default,
+        aten.native_dropout_backward.default,
+    ]
+    for op in elementwise_flop_aten:
+        flop_mapping[op] = elementwise_flop_counter(1, 0)
 
-for op in elementwise_flop_aten:
-    flop_mapping[op] = elementwise_flop_counter(1, 0)
+    # TODO: this will be removed in future
+    zero_flop_aten = [
+        aten.as_strided.default,
+        aten.as_strided_.default,
+        aten.bernoulli_.float,
+        aten.cat.default,
+        aten.clone.default,
+        aten.copy_.default,
+        aten.detach.default,
+        aten.expand.default,
+        aten.empty_like.default,
+        aten.new_empty.default,
+        aten.new_empty_strided.default,
+        aten.ones_like.default,
+        aten._reshape_alias.default,
+        aten.select.int,
+        aten.select_backward.default,
+        aten.squeeze.dim,
+        aten.slice.Tensor,
+        aten.slice_backward.default,
+        aten.split.Tensor,
+        aten.permute.default,
+        aten.t.default,
+        aten.transpose.int,
+        aten._to_copy.default,
+        aten.unsqueeze.default,
+        aten.unbind.int,
+        aten._unsafe_view.default,
+        aten.view.default,
+        aten.where.self,
+        aten.zero_.default,
+        aten.zeros_like.default,
+    ]
 
-# TODO: this will be removed in future
-zero_flop_aten = [
-    aten.as_strided.default,
-    aten.as_strided_.default,
-    aten.bernoulli_.float,
-    aten.cat.default,
-    aten.clone.default,
-    aten.copy_.default,
-    aten.detach.default,
-    aten.expand.default,
-    aten.empty_like.default,
-    aten.new_empty.default,
-    aten.new_empty_strided.default,
-    aten.ones_like.default,
-    aten._reshape_alias.default,
-    aten.select.int,
-    aten.select_backward.default,
-    aten.squeeze.dim,
-    aten.slice.Tensor,
-    aten.slice_backward.default,
-    aten.split.Tensor,
-    aten.permute.default,
-    aten.t.default,
-    aten.transpose.int,
-    aten._to_copy.default,
-    aten.unsqueeze.default,
-    aten.unbind.int,
-    aten._unsafe_view.default,
-    aten.view.default,
-    aten.where.self,
-    aten.zero_.default,
-    aten.zeros_like.default,
-]
+    for op in zero_flop_aten:
+        flop_mapping[op] = zero_flop_jit
 
-for op in zero_flop_aten:
-    flop_mapping[op] = zero_flop_jit
+else:
+    flop_mapping = {}
+    elementwise_flop_aten = []
+    zero_flop_aten = []
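
Note: the guard above uses packaging.version rather than comparing version strings directly, which would mis-order multi-digit components. A standalone sketch of the difference (not part of the patch; it needs only the packaging package that the hunk already imports):

    from packaging import version

    # Lexicographic string comparison gets multi-digit components wrong:
    assert not ('1.12.0' > '1.9.0')    # '1' < '9' character by character
    assert version.parse('1.12.0') > version.parse('1.9.0')

    # PEP 440 local version segments, common in torch builds such as
    # '1.12.0+cu113', still compare correctly against the base release:
    assert version.parse('1.12.0+cu113') >= version.parse('1.12.0')
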
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
index cd9f79953..42430d5a2 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
@@ -207,9 +207,9 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
     assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('op', [torch.add])
 @parameterize('other_dim', [1, 2])
-@run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_handler(op, other_dim):
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
index 778469df4..02c7e0671 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
@@ -203,8 +203,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
     assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
-@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bmm_handler(module):
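
Note on the two reorderings above (and the one in test_linear_handler.py below): Python applies stacked decorators bottom-up, so the decorator written first ends up outermost and runs first. Placing @run_on_environment_flag on top lets the environment check short-circuit before @parameterize fans the test out. A toy sketch with hypothetical stand-ins (skip_unless_flag and fan_out are illustrative, not the ColossalAI implementations):

    import functools
    import os

    def skip_unless_flag(name):
        # Stand-in for run_on_environment_flag: call the wrapped function
        # only when the named environment variable is set to '1'.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                if os.environ.get(name) != '1':
                    print(f'skipped: {name} not set')
                    return
                return fn(*args, **kwargs)
            return wrapper
        return decorator

    def fan_out(values):
        # Stand-in for parameterize: call the wrapped function once per value.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper():
                for v in values:
                    fn(v)
            return wrapper
        return decorator

    @skip_unless_flag('AUTO_PARALLEL')    # listed first -> outermost -> checked once
    @fan_out([1, 2])
    def test_case(other_dim):
        print(f'running with other_dim={other_dim}')

    test_case()    # with AUTO_PARALLEL unset: one skip message, no fan-out
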
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
index 4e01ed243..c5012934c 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
@@ -23,6 +23,7 @@ class GetItemFromTensorModel(nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tensor_handler():
     model = GetItemFromTensorModel()
     tracer = ColoTracer()
@@ -96,6 +97,7 @@ class GetItemFromTupleModel(nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tuple_handler():
     model = GetItemFromTupleModel()
     tracer = ColoTracer()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
index fb8821fae..3d268ea43 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
@@ -308,8 +308,8 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
-@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler(input_shape, bias=False):
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
index d47876af2..f219bc2f3 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
@@ -2,15 +2,15 @@ import pytest
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
-    NormPoolingHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_norm_pool_handler():
     model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
     tracer = ColoTracer()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py
index 613f8f3d0..de277002b 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py
@@ -20,6 +20,7 @@ class ReshapeModel(nn.Module):
         return reshape_node
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_reshape_handler():
     model = ReshapeModel()
     tracer = ColoTracer()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py
index 0c67abc7d..de35fe256 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py
@@ -5,6 +5,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handl
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class TensorConstructorModel(nn.Module):
@@ -18,6 +19,7 @@ class TensorConstructorModel(nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_where_handler():
     model = TensorConstructorModel()
     tracer = ColoTracer()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
index e4d12cd12..a861cb7f5 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
@@ -22,6 +22,7 @@ class ReLuModel(nn.Module):
         return relu_node
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_elementwise_handler():
     model = ReLuModel()
     tracer = ColoTracer()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
index 611402fe8..b504d59c9 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
@@ -10,6 +10,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 def _param_resharding_cost_assertion(node):
@@ -51,6 +52,7 @@ class ConvModel(torch.nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_linear_module():
     model = LinearModel(4, 8)
     physical_mesh_id = torch.arange(0, 4)
@@ -86,6 +88,7 @@ def test_linear_module():
     _param_resharding_cost_assertion(linear_node)
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_module():
     model = ConvModel(3, 6, 2)
     physical_mesh_id = torch.arange(0, 4)
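
With these gates in place, the auto-parallel tests become opt-in. Assuming run_on_environment_flag(name='AUTO_PARALLEL') skips unless the AUTO_PARALLEL environment variable is set (the semantics implied by its name and by the previously gated tests), the affected suite could be exercised with a sketch like:

    import os
    import pytest

    # Opt in to the AUTO_PARALLEL-gated tests, then run the affected suite.
    os.environ['AUTO_PARALLEL'] = '1'
    pytest.main(['tests/test_auto_parallel/test_tensor_shard', '-v'])
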