diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
index 61f8fdff3..617375721 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
@@ -1,3 +1,4 @@
+from functools import reduce
 from typing import Callable, Dict, List, Tuple, Union
 
 import torch
@@ -16,7 +17,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 
 from ..registry import meta_register
 
-__all__ = ['linear_meta_info']
+__all__ = ['linear_meta_info', 'matmul_meta_info']
 
 
 @meta_register.register(torch.nn.functional.linear)
@@ -170,3 +171,235 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     fwd_out = [torch.zeros_like(output_tensor, device='meta')]
 
     return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
+
+
+@meta_register.register(torch.matmul)
+def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+    """torch.matmul meta info generator
+    There are several cases for torch.matmul:
+    1. Vector-vector multiplication => no temp memory; the forward memory cost is 1 element (negligible) and the backward memory cost is the same
+    as the two input vectors.
+    2. Matrix-vector multiplication => if the first input is the matrix, no temp memory is needed; the forward memory cost is the size of the output
+    tensor and the backward memory cost is the size of the two inputs. If the first input is the vector, the forward memory cost is still the size of
+    the output tensor, but during the backward phase it will allocate a temp memory the same size as the input matrix and allocate memory for the
+    gradients of the two inputs.
+    3. Batched matrix-vector multiplication => if the first input is the batched matrix, no temp memory; the forward memory cost is the size of the
+    output tensor and the backward memory cost is the size of the two inputs. If the second input is the batched matrix, matmul will allocate memory
+    for the gradient of the batched matrix in the forward phase (as it creates a new tensor without the former batches), so the forward memory cost is
+    the output tensor plus the newly created matrix (taking the same amount of memory as the input batched matrix). During the backward phase, it will
+    allocate a temp memory the same size as the input batched matrix and allocate a tensor for the gradient of the input vector; the gradient of the
+    batched matrix is stored in the memory allocated during the forward phase.
+    4. Matrix-matrix multiplication => no temp memory; the forward memory cost is the size of the output tensor and the backward memory cost is the
+    size of the two inputs.
+    5. Batched matrix-matrix multiplication => if the first input is the batched matrix, no temp memory; the forward memory cost is the size of the
+    output tensor and the backward memory cost is the size of the two inputs. If the second input is the batched matrix, during the forward phase it
+    will allocate memory for the output and for the gradient of the second input and has a temp memory the same size as the output; during the
+    backward phase, it will allocate memory for the gradient of the first input and has a temp memory as big as the output plus the second input.
+    6. Batched matrix-batched matrix multiplication => if the two inputs have the same batch dimensions, no temp memory; the forward memory cost is
+    the size of the output and the backward memory cost is the size of the two inputs. If the two inputs have different batch dimensions, during the
+    forward phase it will allocate memory for the expanded inputs (so that the batch dimensions match) and for the output; during the backward phase,
+    it has a temp memory the size of the two expanded inputs, and it will allocate memory for the gradients of the two inputs and discard the expanded
+    inputs allocated during the forward phase.
+
+    Returns:
+        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+
+    """
+    # Get input and output tensors
+    input_tensors = [args[0].data, args[1].data]
+    output_tensors = [args[-1].data]
+
+    # Check dimension
+    if all(len(tensor.shape) == 1 for tensor in input_tensors):
+        # Dot
+        fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
+        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2
+
+        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+
+    elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
+        # gemv case 1: matrix-vector multiplication
+        # &
+        # batched gemv case 1: batched matrix-vector multiplication
+
+        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
+
+        # combine the dimensions of output
+        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
+            [output_tensors[0].reshape(-1), input_tensors[1]],
+            output_tensors) + \
+            flop_mapping[torch.ops.aten.mv.default](
+                [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
+                output_tensors)
+
+        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+
+    elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
+        # gemv case 2: vector-matrix multiplication
+        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
+
+        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
+            flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
+
+        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
+                                  parameter=0,
+                                  temp=activation_size(input_tensors[1]),
+                                  buffer=0)
+
+    elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
+        # batched gemv case 2: vector-batched matrix multiplication
+
+        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
+            [output_tensors[0].reshape(-1)])
+
+        # combine the dimensions of output
+        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
+            [output_tensors[0].reshape(-1), input_tensors[0]],
+            output_tensors
+        ) + \
+            flop_mapping[torch.ops.aten.mv.default](
+                [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
+                output_tensors
+            )
+
+        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
+        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+                                  parameter=0,
+                                  temp=activation_size(input_tensors[1]),
+                                  buffer=0)
+
+    elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
+        # gemm & batched gemm case 1: (batched) matrix-matrix multiplication
+
+        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
+            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
+
+        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
+            [input_tensors[1]]
+        ) + \
+            flop_mapping[torch.ops.aten.mm.default](
+                [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
+                [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
+            )
+
+        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+
+    elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
+        # batched gemm case 2: matrix-batched matrix multiplication
+        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
+            input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
+            input_tensors[0].transpose(0, 1)
+        ], [output_tensors[0].transpose(-2, -1)])
+
+        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+            [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
+            [input_tensors[0]]
+        ) + \
+            flop_mapping[torch.ops.aten.mm.default](
+                [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
+                [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
+            )
+
+        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
+                                  temp=activation_size(output_tensors))
+        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+                                  parameter=0,
+                                  temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
+
+    elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
+        # Batched matrix-batched matrix multiplication
+        # Fetch shape of the two inputs and see if the batch dimensions are the same
+        _is_batch_dims_same = True
+        if len(input_tensors[0].shape) == len(input_tensors[1].shape):
+            for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
+                if shape_0 != shape_1:
+                    _is_batch_dims_same = False
+                    break
+        else:
+            _is_batch_dims_same = False
+
+        # retrieve dimensions
+        input_dim_00 = input_tensors[0].shape[-2]
+        input_dim_01 = input_tensors[0].shape[-1]
+        input_dim_10 = input_tensors[1].shape[-2]
+        input_dim_11 = input_tensors[1].shape[-1]
+        output_dim_0 = output_tensors[0].shape[-2]
+        output_dim_1 = output_tensors[0].shape[-1]
+
+        if _is_batch_dims_same:
+            # Case 1: batch dimensions are the same
+
+            # Forward compute cost: C = A * B
+            fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
+                input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
+                input_tensors[1].reshape(-1, input_dim_10, input_dim_11)
+            ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+
+            # Backward compute cost: dB = A^T * dC, dA = dC * B^T
+            bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+                [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+                [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
+            ) + \
+                flop_mapping[torch.ops.aten.bmm.default](
+                    [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
+                    [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
+                )
+
+            fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
+            bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
+
+        else:
+            # Case 2: batch dimensions are different
+            batch_dims = output_tensors[0].shape[:-2]
+            extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
+                                          input_dim_00,
+                                          input_dim_01,
+                                          device="meta")
+            extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
+                                          input_dim_10,
+                                          input_dim_11,
+                                          device="meta")
+
+            # Forward compute cost: C = A * B
+            fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+                [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+
+            # Backward compute cost: dB = A^T * dC, dA = dC * B^T
+            bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+                [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+                [extended_input_1]
+            ) + \
+                flop_mapping[torch.ops.aten.bmm.default](
+                    [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
+                    [extended_input_0]
+                )
+
+            fwd_mem_cost = MemoryCost(
+                activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
+            bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
+                                      activation_size([extended_input_0, extended_input_1]),
+                                      temp=activation_size([extended_input_0, extended_input_1]))
+
+    # compute cost
+    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
+
+    # memory cost
+    total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+                            parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+                            temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+                            buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+
+    memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
+
+    # store fwd_in, fwd_buffer, fwd_out
+    fwd_in = input_tensors
+    fwd_buffer = []
+    fwd_out = output_tensors
+
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
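Editor's note: to make the bookkeeping above concrete, here is a minimal hand computation (not part of the patch) of what the plain matrix-matrix branch reports for a single fp32 matmul A(64, 128) @ B(128, 192). The FLOP figures follow matmul_flop_jit, which counts prod(shape_A) * N multiply-accumulates per mm call; the byte figures assume 4-byte elements.

# Illustrative sketch only: hand-checked numbers for the plain matrix-matrix branch.
m, k, n = 64, 128, 192
bytes_per_elem = 4                                        # fp32 assumed

fwd_macs = m * k * n                                      # C = A @ B
bwd_macs = 2 * m * k * n                                  # dA = dC @ B^T and dB = A^T @ dC

fwd_activation_bytes = m * n * bytes_per_elem             # output tensor kept for backward
bwd_activation_bytes = (m * k + k * n) * bytes_per_elem   # gradients of the two inputs

print(fwd_macs, bwd_macs, fwd_activation_bytes, bwd_activation_bytes)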
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index 131c35156..f3c9d0cbf 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
 
 from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
 from ..utils import recover_sharding_spec_for_broadcast_shape
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import (
     BatchedMatMulStrategyGenerator,
@@ -326,7 +326,7 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
 
 @operator_registry.register(torch.matmul)
 @operator_registry.register(torch.Tensor.matmul)
-class MatMulHandler(NodeHandler):
+class MatMulHandler(MetaInfoNodeHandler):
     """
     The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
     According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index fbab2b61e..c6f8d035a 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -16,6 +16,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 )
 from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.logging import get_dist_logger
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 
 from .strategy import StrategyGenerator
@@ -266,6 +267,10 @@ class MetaInfoNodeHandler(NodeHandler):
 
             # attach metainfos to the handler
             setattr(self, "metainfo_vector", metainfo_vector)
+        else:
+            logger = get_dist_logger()
+            logger.warning(f'The target function {target} is not patched yet')
+
         return self.strategies_vector
 
@@ -317,4 +322,8 @@ class MetaInfoModuleHandler(ModuleHandler):
 
             # attach metainfos to the handler
             setattr(self, "metainfo_vector", metainfo_vector)
+        else:
+            logger = get_dist_logger()
+            logger.warning(f'The target function {target} is not patched yet')
+
         return self.strategies_vector
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index d780ef6d4..6bdec865f 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -20,7 +20,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     # Inputs contains the shapes of two matrices.
     input_shapes = [v.shape for v in inputs]
     assert len(input_shapes) == 2, input_shapes
-    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+
+    # There are three cases: 1) gemm, 2) gemv, 3) dot
+    if all(len(shape) == 2 for shape in input_shapes):
+        # gemm
+        assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+    elif all(len(shape) == 1 for shape in input_shapes):
+        # dot
+        assert input_shapes[0][0] == input_shapes[1][0], input_shapes
+
+        # expand shape
+        input_shapes[0] = torch.Size([1, input_shapes[0][0]])
+        input_shapes[1] = torch.Size([input_shapes[1][0], 1])
+    else:
+        # gemv
+        if len(input_shapes[0]) == 1:
+            assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
+            input_shapes.reverse()
+        else:
+            assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
+
+        # expand the shape of the vector to [batch size, 1]
+        input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
     flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
     return flops
@@ -204,8 +225,10 @@ def zero_flop_jit(*args):
 
 if version.parse(torch.__version__) >= version.parse('1.12.0'):
     flop_mapping = {
-        # gemm
+        # gemm, gemv and dot
         aten.mm.default: matmul_flop_jit,
+        aten.mv.default: matmul_flop_jit,
+        aten.dot.default: matmul_flop_jit,
         aten.matmul.default: matmul_flop_jit,
         aten.addmm.default: addmm_flop_jit,
         aten.bmm.default: bmm_flop_jit,
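Editor's note: the shape expansion in matmul_flop_jit above exists so that a single formula, prod(input_shapes[0]) * input_shapes[-1][-1], covers gemm, gemv and dot alike. The standalone sketch below (not part of the patch) shows how the expanded shapes reproduce the expected multiply-accumulate counts.

# Illustrative sketch only: the generic formula applied to the expanded shapes.
import operator
from functools import reduce


def macs(shape_0, shape_1):
    # Mirrors matmul_flop_jit's final line: prod(shape_0) * last dim of shape_1.
    return reduce(operator.mul, shape_0) * shape_1[-1]


# dot: (128,) . (128,) is treated as (1, 128) @ (128, 1)
assert macs((1, 128), (128, 1)) == 128
# mat-vec: (64, 128) @ (128,) treats the vector as (128, 1)
assert macs((64, 128), (128, 1)) == 64 * 128
# vec-mat: (128,) @ (128, 64) reverses the operands first, then the vector becomes (128, 1)
assert macs((128, 64), (128, 1)) == 128 * 64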
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py
new file mode 100644
index 000000000..3fb9c3d85
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py
@@ -0,0 +1,145 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    MemoryCost,
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+    StrategiesVector,
+    TrainCycleItem,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
+
+if torch.__version__ >= '1.12.0':
+    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@parameterize(
+    'tensor_shapes',
+    [
+        [[128], [128]],    # dot product
+        [[64, 128], [128]],    # mat-vec
+        [[128], [128, 64]],    # vec-mat
+        [[64, 64, 128], [128]],    # batched mat-vec
+        [[128], [64, 128, 64]],    # vec-batched mat
+        [[64, 128], [128, 192]],    # mat-mat
+        [[64, 64, 128], [128, 192]],    # batched mat-mat
+        [[64, 128], [64, 128, 192]],    # mat-batched mat
+        [[64, 64, 128], [64, 128, 192]],    # batched mat-batched mat (matched batch dims)
+        [[64, 1, 64, 128], [64, 128, 192]],    # batched mat-batched mat (unmatched batch dims)
+    ])
+def test_matmul_function_meta_info(tensor_shapes):
+    meta_func = meta_register.get(torch.matmul)
+
+    # construct meta tensors
device="meta") + other_tensor = torch.rand(*tensor_shapes[1], device="meta") + output_tensor = torch.matmul(input_tensor, other_tensor) + + # construct operation data + input_data = OperationData( + name="input", + data=input_tensor, + type=OperationDataType.ARG, + logical_shape=input_tensor.shape, + ) + other_data = OperationData( + name="other", + data=other_tensor, + type=OperationDataType.ARG, + logical_shape=other_tensor.shape, + ) + output_data = OperationData( + name="output", + data=output_tensor, + type=OperationDataType.OUTPUT, + logical_shape=output_tensor.shape, + ) + + # construct args and kwargs + args = [input_data, other_data, output_data] + kwargs = {'inplace': False} + + # estimated results + compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) + + # actual results + input_real_tensor = torch.rand(*tensor_shapes[0], device="cuda:0") + other_real_tensor = torch.rand(*tensor_shapes[1], device="cuda:0") + + input_real_tensor.requires_grad = True + other_real_tensor.requires_grad = True + + # fwd + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + output_real_tensor = torch.matmul(input_real_tensor, other_real_tensor) + fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + # bwd + upstream_grad = torch.rand_like(output_real_tensor) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + torch.autograd.backward(output_real_tensor, upstream_grad) + bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 + bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 + + compute_cost: TrainCycleItem + memory_cost: TrainCycleItem + + print("=====================") + print(f"input shapes: {tensor_shapes[0]}, {tensor_shapes[1]}") + print(f"output shapes: {output_tensor.shape}") + + # estimated results + print("Estimated Results") + + # compute cost + print("compute_cost:") + print(f" fwd: {compute_cost.fwd}") + print(f" bwd: {compute_cost.bwd}") + + # memory cost + print("memory_cost:") + # fwd + print(f" fwd activation: {memory_cost.fwd.activation / 1024} KB") + print(f" fwd buffer: {memory_cost.fwd.buffer / 1024} KB") + print(f" fwd temp: {memory_cost.fwd.temp / 1024} KB") + print(f" fwd parameter: {memory_cost.fwd.parameter / 1024} KB") + + # bwd + print(f" bwd activation: {memory_cost.bwd.activation / 1024} KB") + print(f" bwd buffer: {memory_cost.bwd.buffer / 1024} KB") + print(f" bwd temp: {memory_cost.bwd.temp / 1024} KB") + print(f" bwd parameter: {memory_cost.bwd.parameter / 1024} KB") + + # actual results + print("Actual Results") + + print("memory_cost:") + # fwd + print(f" fwd allocated: {fwd_allocated / 1024} KB") + print(f" fwd peak: {fwd_peak / 1024} KB") + + # bwd + print(f" bwd allocated: {bwd_allocated / 1024} KB") + print(f" bwd peak: {bwd_peak / 1024} KB") + + +if __name__ == '__main__': + test_matmul_function_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py index 306c45f56..91b3ae27d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -1,3 +1,4 @@ +import pytest import torch import torch.nn as nn @@ -24,6 +25,7 @@ class MatMulModule(nn.Module): return torch.matmul(x1, x2) +@pytest.mark.skipif(torch.__version__ < '1.12.0', 
reason="need pytorch 1.12.0 or higher for aten level operations") @parameterize( 'tensor_shapes', [