mirror of https://github.com/hpcaitech/ColossalAI

[hotfix] fix aten default bug (#2158)

parent a4b4bb01d6
commit 16335cb537

@@ -7,6 +7,7 @@ from numbers import Number
 from typing import Any, Callable, List

 import torch
+from packaging import version

 aten = torch.ops.aten

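For context, packaging.version performs a PEP 440-aware comparison, which is what makes the new guard safe; a plain string comparison of version numbers is not. A minimal, self-contained sketch (the version literals are illustrative):

import torch
from packaging import version

# version.parse orders multi-digit components correctly and tolerates
# suffixes such as '1.13.0a0+gitabcd' from nightly builds.
assert version.parse('1.12.0') > version.parse('1.9.0')
# A lexical string comparison gets the same pair wrong ('1' < '9'):
assert not ('1.12.0' > '1.9.0')

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    print('torch is new enough for the aten-overload flop tables below')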

@@ -188,131 +189,136 @@ def zero_flop_jit(*args):
     return 0


-flop_mapping = {
-    # gemm
-    aten.mm.default: matmul_flop_jit,
-    aten.matmul.default: matmul_flop_jit,
-    aten.addmm.default: addmm_flop_jit,
-    aten.bmm.default: bmm_flop_jit,
-
-    # convolution
-    aten.convolution.default: conv_flop_jit,
-    aten._convolution.default: conv_flop_jit,
-    aten.convolution_backward.default: conv_backward_flop_jit,
-
-    # normalization
-    aten.native_batch_norm.default: batchnorm_flop_jit,
-    aten.native_batch_norm_backward.default: batchnorm_flop_jit,
-    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
-    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
-    aten.native_layer_norm.default: norm_flop_counter(2, 0),
-    aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
-
-    # pooling
-    aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
-    aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
-    aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
-    aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
-    aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
-    aten.max_pool1d.default: elementwise_flop_counter(1, 0),
-    aten.max_pool2d.default: elementwise_flop_counter(1, 0),
-    aten.max_pool3d.default: elementwise_flop_counter(1, 0),
-    aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
-    aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
-    aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
-    aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
-    aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
-    aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
-    aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
-    aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
-    aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
-    aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
-    aten.embedding.default: elementwise_flop_counter(1, 0),
-}
-
-elementwise_flop_aten = [
-    # basic op
-    aten.add.Tensor,
-    aten.add_.Tensor,
-    aten.div.Tensor,
-    aten.div_.Tensor,
-    aten.div.Scalar,
-    aten.div_.Scalar,
-    aten.mul.Tensor,
-    aten.mul.Scalar,
-    aten.mul_.Tensor,
-    aten.neg.default,
-    aten.pow.Tensor_Scalar,
-    aten.rsub.Scalar,
-    aten.sum.default,
-    aten.sum.dim_IntList,
-    aten.mean.dim,
-
-    # activation op
-    aten.hardswish.default,
-    aten.hardswish_.default,
-    aten.hardswish_backward.default,
-    aten.hardtanh.default,
-    aten.hardtanh_.default,
-    aten.hardtanh_backward.default,
-    aten.hardsigmoid_backward.default,
-    aten.hardsigmoid.default,
-    aten.gelu.default,
-    aten.gelu_backward.default,
-    aten.silu.default,
-    aten.silu_.default,
-    aten.silu_backward.default,
-    aten.sigmoid.default,
-    aten.sigmoid_backward.default,
-    aten._softmax.default,
-    aten._softmax_backward_data.default,
-    aten.relu_.default,
-    aten.relu.default,
-    aten.tanh.default,
-    aten.tanh_backward.default,
-    aten.threshold_backward.default,
-
-    # dropout
-    aten.native_dropout.default,
-    aten.native_dropout_backward.default,
-]
-
-for op in elementwise_flop_aten:
-    flop_mapping[op] = elementwise_flop_counter(1, 0)
-
-zero_flop_aten = [
-    aten.as_strided.default,
-    aten.as_strided_.default,
-    aten.bernoulli_.float,
-    aten.cat.default,
-    aten.clone.default,
-    aten.copy_.default,
-    aten.detach.default,
-    aten.expand.default,
-    aten.empty_like.default,
-    aten.new_empty.default,
-    aten.new_empty_strided.default,
-    aten.ones_like.default,
-    aten._reshape_alias.default,
-    aten.select.int,
-    aten.select_backward.default,
-    aten.squeeze.dim,
-    aten.slice.Tensor,
-    aten.slice_backward.default,
-    aten.split.Tensor,
-    aten.permute.default,
-    aten.t.default,
-    aten.transpose.int,
-    aten._to_copy.default,
-    aten.unsqueeze.default,
-    aten.unbind.int,
-    aten._unsafe_view.default,
-    aten.view.default,
-    aten.where.self,
-    aten.zero_.default,
-    aten.zeros_like.default,
-]
-
-# TODO: this will be removed in future
-for op in zero_flop_aten:
-    flop_mapping[op] = zero_flop_jit
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    flop_mapping = {
+        # gemm
+        aten.mm.default: matmul_flop_jit,
+        aten.matmul.default: matmul_flop_jit,
+        aten.addmm.default: addmm_flop_jit,
+        aten.bmm.default: bmm_flop_jit,
+
+        # convolution
+        aten.convolution.default: conv_flop_jit,
+        aten._convolution.default: conv_flop_jit,
+        aten.convolution_backward.default: conv_backward_flop_jit,
+
+        # normalization
+        aten.native_batch_norm.default: batchnorm_flop_jit,
+        aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+        aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+        aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
+        aten.native_layer_norm.default: norm_flop_counter(2, 0),
+        aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
+
+        # pooling
+        aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
+        aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
+        aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
+        aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
+        aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+        aten.max_pool1d.default: elementwise_flop_counter(1, 0),
+        aten.max_pool2d.default: elementwise_flop_counter(1, 0),
+        aten.max_pool3d.default: elementwise_flop_counter(1, 0),
+        aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
+        aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
+        aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
+        aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
+        aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
+        aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
+        aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
+        aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
+        aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+        aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
+        aten.embedding.default: elementwise_flop_counter(1, 0),
+    }
+
+    elementwise_flop_aten = [
+        # basic op
+        aten.add.Tensor,
+        aten.add_.Tensor,
+        aten.div.Tensor,
+        aten.div_.Tensor,
+        aten.div.Scalar,
+        aten.div_.Scalar,
+        aten.mul.Tensor,
+        aten.mul.Scalar,
+        aten.mul_.Tensor,
+        aten.neg.default,
+        aten.pow.Tensor_Scalar,
+        aten.rsub.Scalar,
+        aten.sum.default,
+        aten.sum.dim_IntList,
+        aten.mean.dim,
+
+        # activation op
+        aten.hardswish.default,
+        aten.hardswish_.default,
+        aten.hardswish_backward.default,
+        aten.hardtanh.default,
+        aten.hardtanh_.default,
+        aten.hardtanh_backward.default,
+        aten.hardsigmoid_backward.default,
+        aten.hardsigmoid.default,
+        aten.gelu.default,
+        aten.gelu_backward.default,
+        aten.silu.default,
+        aten.silu_.default,
+        aten.silu_backward.default,
+        aten.sigmoid.default,
+        aten.sigmoid_backward.default,
+        aten._softmax.default,
+        aten._softmax_backward_data.default,
+        aten.relu_.default,
+        aten.relu.default,
+        aten.tanh.default,
+        aten.tanh_backward.default,
+        aten.threshold_backward.default,
+
+        # dropout
+        aten.native_dropout.default,
+        aten.native_dropout_backward.default,
+    ]
+
+    for op in elementwise_flop_aten:
+        flop_mapping[op] = elementwise_flop_counter(1, 0)
+
+    # TODO: this will be removed in future
+    zero_flop_aten = [
+        aten.as_strided.default,
+        aten.as_strided_.default,
+        aten.bernoulli_.float,
+        aten.cat.default,
+        aten.clone.default,
+        aten.copy_.default,
+        aten.detach.default,
+        aten.expand.default,
+        aten.empty_like.default,
+        aten.new_empty.default,
+        aten.new_empty_strided.default,
+        aten.ones_like.default,
+        aten._reshape_alias.default,
+        aten.select.int,
+        aten.select_backward.default,
+        aten.squeeze.dim,
+        aten.slice.Tensor,
+        aten.slice_backward.default,
+        aten.split.Tensor,
+        aten.permute.default,
+        aten.t.default,
+        aten.transpose.int,
+        aten._to_copy.default,
+        aten.unsqueeze.default,
+        aten.unbind.int,
+        aten._unsafe_view.default,
+        aten.view.default,
+        aten.where.self,
+        aten.zero_.default,
+        aten.zeros_like.default,
+    ]
+
+    for op in zero_flop_aten:
+        flop_mapping[op] = zero_flop_jit
+else:
+    flop_mapping = {}
+    elementwise_flop_aten = {}
+    zero_flop_aten = {}
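
For orientation, every key in flop_mapping is a concrete aten overload and every value is a callable that estimates FLOPs for one call. A hypothetical dispatch helper shows the intended lookup; the name count_flops and the (inputs, outputs) counter signature are assumptions for illustration, not the profiler's actual API:

def count_flops(func, inputs, outputs):
    # func is the aten overload recorded during tracing, e.g. aten.mm.default.
    handler = flop_mapping.get(func)
    if handler is None:
        # Ops without a registered counter contribute no FLOPs.
        return 0
    return handler(inputs, outputs)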

@@ -207,9 +207,9 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
     assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence


+@run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('op', [torch.add])
 @parameterize('other_dim', [1, 2])
-@run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_handler(op, other_dim):
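
The test hunks in this commit all make the same fix: @run_on_environment_flag moves to the top of the decorator stack. Python applies decorators bottom-up, so the topmost decorator becomes the outermost wrapper, and the environment check now runs before the parameterization machinery. A self-contained sketch of that ordering:

def tag(label):
    # Build a decorator that announces when its wrapper is entered.
    def deco(f):
        def wrapped(*args, **kwargs):
            print(f'enter {label}')
            return f(*args, **kwargs)
        return wrapped
    return deco

@tag('env_flag')        # topmost decorator: outermost wrapper, checked first
@tag('parameterize')    # applied first, but wrapped by env_flag
def test():
    print('test body')

test()    # prints: enter env_flag, enter parameterize, test body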

@@ -203,8 +203,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
     assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bmm_handler(module):

@@ -23,6 +23,7 @@ class GetItemFromTensorModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tensor_handler():
     model = GetItemFromTensorModel()
     tracer = ColoTracer()
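
The remaining hunks attach the same @run_on_environment_flag(name='AUTO_PARALLEL') gate to tests that previously always ran. A plausible stand-in for such a guard, assuming only standard pytest machinery (the helper require_env_flag below is hypothetical, not ColossalAI's implementation):

import os

import pytest

def require_env_flag(name):
    # Skip the decorated test unless the named environment variable is '1'.
    return pytest.mark.skipif(os.environ.get(name) != '1',
                              reason=f'environment flag {name} is not set')

@require_env_flag('AUTO_PARALLEL')
def test_example():
    assert True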

@@ -96,6 +97,7 @@ class GetItemFromTupleModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tuple_handler():
     model = GetItemFromTupleModel()
     tracer = ColoTracer()

@@ -308,8 +308,8 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler(input_shape, bias=False):

@@ -2,15 +2,15 @@ import pytest
 import torch
 import torch.nn as nn

-from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
-    NormPoolingHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_norm_pool_handler():
     model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
     tracer = ColoTracer()

@@ -20,6 +20,7 @@ class ReshapeModel(nn.Module):
         return reshape_node


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_reshape_handler():
     model = ReshapeModel()
     tracer = ColoTracer()

@@ -5,6 +5,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handl
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class TensorConstructorModel(nn.Module):

@@ -18,6 +19,7 @@ class TensorConstructorModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_where_handler():
     model = TensorConstructorModel()
     tracer = ColoTracer()

@@ -22,6 +22,7 @@ class ReLuModel(nn.Module):
         return relu_node


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_elementwise_handler():
     model = ReLuModel()
     tracer = ColoTracer()

@@ -10,6 +10,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 def _param_resharding_cost_assertion(node):

@@ -51,6 +52,7 @@ class ConvModel(torch.nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_linear_module():
     model = LinearModel(4, 8)
     physical_mesh_id = torch.arange(0, 4)

@@ -86,6 +88,7 @@ def test_linear_module():
     _param_resharding_cost_assertion(linear_node)


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_module():
     model = ConvModel(3, 6, 2)
     physical_mesh_id = torch.arange(0, 4)