mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] Patch meta information of `torch.matmul` (#2584)
* [autoparallel] matmul metainfo * [auto_parallel] remove unused print * [tests] skip test_matmul_handler when torch version is lower than 1.12.0pull/2581/head
parent
4ae02c4b1c
commit
90a9fdd91d
|
@ -1,3 +1,4 @@
|
|||
from functools import reduce
|
||||
from typing import Callable, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
@ -16,7 +17,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec
|
|||
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ['linear_meta_info']
|
||||
__all__ = ['linear_meta_info', 'matmul_meta_info']
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.functional.linear)
|
||||
|
@ -170,3 +171,235 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
|||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||
|
||||
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
|
||||
|
||||
|
||||
@meta_register.register(torch.matmul)
|
||||
def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.matmul meta info generator
|
||||
There are several cases for torch.matmul:
|
||||
1. Vector-vector multiplication => no temp memory, forward memory cost is 1 element (could be neglected), backward memory cost is the same
|
||||
as two input vectors.
|
||||
2. Matrix-vector multiplication => if the first input is matrix, no temp memory is needed, otherwise, there is a temp memory in the backward
|
||||
phase for the transpose of the matrix. The forward memory cost is the size of output tensor, backward memory cost is the size of the two inputs; if
|
||||
the first input is vector, the forward memory cost is the size of the output tensor, and during the backward phase, it will allocate a temp memory
|
||||
the same size as the input matrix, and allocate memory for the gradient of two inputs.
|
||||
3. Batched Matrix-vector multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of
|
||||
output tensor, backward memory cost is the size of the two inputs; if the second input is the batched matrix, the matmul will allocate memory for
|
||||
the gradient of the batched matrix in the forward phase (as they create a new tensor without the former batches), so the forward memory cost is
|
||||
the output tensor and the newly created matrix (take the same amount of memory of the input batched matrix). During the backward phase, it will
|
||||
allocate a temp memory the same size as input batched matrix, and allocate a tensor for the gradient of the input vector. The gradient of the batched
|
||||
matrix will be stored in the memory allocated during the forward phase.
|
||||
3. Matrix-matrix multiplication => no temp memory, forward memory is the size of output tensor, backward memory is the size of the two inputs
|
||||
4. Batched matrix-matrix multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of two
|
||||
inputs and backward memory cost is the size of the output tensor; if the second input is the batched matrix, during the forward phase it will allocate
|
||||
memory for the output and gradient of the second input, and has a temp memory the same size as the output, during the backward phase, it
|
||||
will allocate memory for the gradient of the first input and has a temp memory which is as big as output and the second input.
|
||||
5. Batched matrix-batched matrix multiplication => if the two inputs have the same batch dimensions, no temp memory, the forward memory cost is the size
|
||||
of output, backward memory cost is the size of the two inputs; it the two inputs have different batch dimensions, during the forward phase it will allocate
|
||||
memory of the expanded inputs (so that the batch dimensions could match) and the output, and during the backward phase, it has a temp memory of the size of
|
||||
two expanded inputs, and it will allocate memory for the gradient of the two inputs and discard the expanded inputs allocated during the forward phase.
|
||||
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs
|
||||
|
||||
"""
|
||||
# Get input and output tensors
|
||||
input_tensors = [args[0].data, args[1].data]
|
||||
output_tensors = [args[-1].data]
|
||||
|
||||
# Check dimension
|
||||
if all(len(tensor.shape) == 1 for tensor in input_tensors):
|
||||
# Dot
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2
|
||||
|
||||
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
|
||||
|
||||
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
|
||||
# gemv case 1: matrix-vector multiplication
|
||||
# &
|
||||
# batched gemv case 1: batched matrix-vector multiplication
|
||||
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
|
||||
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
|
||||
|
||||
# combine the dimensions of output
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
|
||||
[output_tensors[0].reshape(-1), input_tensors[1]],
|
||||
output_tensors) + \
|
||||
flop_mapping[torch.ops.aten.mv.default](
|
||||
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
|
||||
output_tensors)
|
||||
|
||||
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
|
||||
|
||||
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
|
||||
# gemv case 2: vector-matrix multiplication
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
|
||||
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
|
||||
flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
|
||||
|
||||
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
|
||||
parameter=0,
|
||||
temp=activation_size(input_tensors[1]),
|
||||
buffer=0)
|
||||
|
||||
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
|
||||
# batched gemv case 2: vector-batched matrix multiplication
|
||||
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
|
||||
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
|
||||
[output_tensors[0].reshape(-1)])
|
||||
|
||||
# combine the dimensions of output
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
|
||||
[output_tensors[0].reshape(-1), input_tensors[0]],
|
||||
output_tensors
|
||||
) + \
|
||||
flop_mapping[torch.ops.aten.mv.default](
|
||||
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
|
||||
output_tensors
|
||||
)
|
||||
|
||||
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
|
||||
parameter=0,
|
||||
temp=activation_size(input_tensors[1]),
|
||||
buffer=0)
|
||||
|
||||
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
|
||||
# gemm & batched gemm case 1: batched matrix-matrix multiplication
|
||||
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
|
||||
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
|
||||
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
|
||||
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
|
||||
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
|
||||
[input_tensors[1]]
|
||||
) + \
|
||||
flop_mapping[torch.ops.aten.mm.default](
|
||||
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
|
||||
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
|
||||
)
|
||||
|
||||
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
|
||||
|
||||
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
|
||||
# batched gemm case 2: matrix-batched matrix multiplication
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
|
||||
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
|
||||
0, 1)
|
||||
], [output_tensors[0].transpose(-2, -1)])
|
||||
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
|
||||
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
|
||||
[input_tensors[0]]
|
||||
) + \
|
||||
flop_mapping[torch.ops.aten.mm.default](
|
||||
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
|
||||
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
|
||||
)
|
||||
|
||||
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
|
||||
temp=activation_size(output_tensors))
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
|
||||
parameter=0,
|
||||
temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
|
||||
|
||||
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
|
||||
# Batched matrix-batched matrix multiplication
|
||||
# Fetch shape of the two inputs and see if the batch dimensions are the same
|
||||
_is_batch_dims_same = True
|
||||
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
|
||||
for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
|
||||
if shape_0 != shape_1:
|
||||
_is_batch_dims_same = False
|
||||
break
|
||||
else:
|
||||
_is_batch_dims_same = False
|
||||
|
||||
# retireve dimensions
|
||||
input_dim_00 = input_tensors[0].shape[-2]
|
||||
input_dim_01 = input_tensors[0].shape[-1]
|
||||
input_dim_10 = input_tensors[1].shape[-2]
|
||||
input_dim_11 = input_tensors[1].shape[-1]
|
||||
output_dim_0 = output_tensors[0].shape[-2]
|
||||
output_dim_1 = output_tensors[0].shape[-1]
|
||||
|
||||
if _is_batch_dims_same:
|
||||
# Case 1: batch dimensions are the same
|
||||
|
||||
# Forward compute cost: C = A * B
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
|
||||
input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
|
||||
-1, input_dim_10, input_dim_11)
|
||||
], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
|
||||
|
||||
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
|
||||
[input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
|
||||
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
|
||||
) + \
|
||||
flop_mapping[torch.ops.aten.bmm.default](
|
||||
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
|
||||
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
|
||||
)
|
||||
|
||||
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
|
||||
|
||||
else:
|
||||
# Case 2: batch dimensions are different
|
||||
batch_dims = output_tensors[0].shape[:-2]
|
||||
extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
|
||||
input_dim_00,
|
||||
input_dim_01,
|
||||
device="meta")
|
||||
extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
|
||||
input_dim_10,
|
||||
input_dim_11,
|
||||
device="meta")
|
||||
|
||||
# Forward compute cost: C = A * B
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
|
||||
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
|
||||
|
||||
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
|
||||
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
|
||||
[extended_input_1]
|
||||
) + \
|
||||
flop_mapping[torch.ops.aten.bmm.default](
|
||||
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
|
||||
[extended_input_0]
|
||||
)
|
||||
|
||||
fwd_mem_cost = MemoryCost(
|
||||
activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
|
||||
activation_size([extended_input_0, extended_input_1]),
|
||||
temp=activation_size([extended_input_0, extended_input_1]))
|
||||
|
||||
# compute cost
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# memory cost
|
||||
total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
|
||||
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
|
||||
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
|
||||
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
|
||||
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = input_tensors
|
||||
fwd_buffer = []
|
||||
fwd_out = output_tensors
|
||||
|
||||
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
|
||||
|
|
|
@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
|
|||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from ..utils import recover_sharding_spec_for_broadcast_shape
|
||||
from .node_handler import NodeHandler
|
||||
from .node_handler import MetaInfoNodeHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import (
|
||||
BatchedMatMulStrategyGenerator,
|
||||
|
@ -326,7 +326,7 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
|
|||
|
||||
@operator_registry.register(torch.matmul)
|
||||
@operator_registry.register(torch.Tensor.matmul)
|
||||
class MatMulHandler(NodeHandler):
|
||||
class MatMulHandler(MetaInfoNodeHandler):
|
||||
"""
|
||||
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
|
||||
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
|
||||
|
|
|
@ -16,6 +16,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
|||
)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
|
||||
from .strategy import StrategyGenerator
|
||||
|
@ -266,6 +267,10 @@ class MetaInfoNodeHandler(NodeHandler):
|
|||
# attach metainfos to the handler
|
||||
setattr(self, "metainfo_vector", metainfo_vector)
|
||||
|
||||
else:
|
||||
logger = get_dist_logger()
|
||||
logger.warning(f'The target function {target} is not patched yet, ')
|
||||
|
||||
return self.strategies_vector
|
||||
|
||||
|
||||
|
@ -317,4 +322,8 @@ class MetaInfoModuleHandler(ModuleHandler):
|
|||
# attach metainfos to the handler
|
||||
setattr(self, "metainfo_vector", metainfo_vector)
|
||||
|
||||
else:
|
||||
logger = get_dist_logger()
|
||||
logger.warning(f'The target function {target} is not patched yet')
|
||||
|
||||
return self.strategies_vector
|
||||
|
|
|
@ -20,7 +20,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
|||
# Inputs contains the shapes of two matrices.
|
||||
input_shapes = [v.shape for v in inputs]
|
||||
assert len(input_shapes) == 2, input_shapes
|
||||
|
||||
# There are three cases: 1) gemm, 2) gemv, 3) dot
|
||||
if all(len(shape) == 2 for shape in input_shapes):
|
||||
# gemm
|
||||
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
|
||||
elif all(len(shape) == 1 for shape in input_shapes):
|
||||
# dot
|
||||
assert input_shapes[0][0] == input_shapes[1][0], input_shapes
|
||||
|
||||
# expand shape
|
||||
input_shapes[0] = torch.Size([1, input_shapes[0][0]])
|
||||
input_shapes[1] = torch.Size([input_shapes[1][0], 1])
|
||||
else:
|
||||
# gemv
|
||||
if len(input_shapes[0]) == 1:
|
||||
assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
|
||||
input_shapes.reverse()
|
||||
else:
|
||||
assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
|
||||
|
||||
# expand the shape of the vector to [batch size, 1]
|
||||
input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
|
||||
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
|
||||
return flops
|
||||
|
||||
|
@ -204,8 +225,10 @@ def zero_flop_jit(*args):
|
|||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
flop_mapping = {
|
||||
# gemm
|
||||
# gemm, gemv and dot
|
||||
aten.mm.default: matmul_flop_jit,
|
||||
aten.mv.default: matmul_flop_jit,
|
||||
aten.dot.default: matmul_flop_jit,
|
||||
aten.matmul.default: matmul_flop_jit,
|
||||
aten.addmm.default: addmm_flop_jit,
|
||||
aten.bmm.default: bmm_flop_jit,
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@parameterize(
|
||||
'tensor_shapes',
|
||||
[
|
||||
[[128], [128]], # dot product
|
||||
[[64, 128], [128]], # mat-vec
|
||||
[[128], [128, 64]], # vec-mat
|
||||
[[64, 64, 128], [128]], # batched mat-vec
|
||||
[[128], [64, 128, 64]], # vec-batched mat
|
||||
[[64, 128], [128, 192]], # mat-mat
|
||||
[[64, 64, 128], [128, 192]], # batched mat-mat
|
||||
[[64, 128], [64, 128, 192]], # mat-batched mat
|
||||
[[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims)
|
||||
[[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims)
|
||||
])
|
||||
def test_matmul_function_meta_info(tensor_shapes):
|
||||
meta_func = meta_register.get(torch.matmul)
|
||||
|
||||
# construct meta tensors
|
||||
input_tensor = torch.rand(*tensor_shapes[0], device="meta")
|
||||
other_tensor = torch.rand(*tensor_shapes[1], device="meta")
|
||||
output_tensor = torch.matmul(input_tensor, other_tensor)
|
||||
|
||||
# construct operation data
|
||||
input_data = OperationData(
|
||||
name="input",
|
||||
data=input_tensor,
|
||||
type=OperationDataType.ARG,
|
||||
logical_shape=input_tensor.shape,
|
||||
)
|
||||
other_data = OperationData(
|
||||
name="other",
|
||||
data=other_tensor,
|
||||
type=OperationDataType.ARG,
|
||||
logical_shape=other_tensor.shape,
|
||||
)
|
||||
output_data = OperationData(
|
||||
name="output",
|
||||
data=output_tensor,
|
||||
type=OperationDataType.OUTPUT,
|
||||
logical_shape=output_tensor.shape,
|
||||
)
|
||||
|
||||
# construct args and kwargs
|
||||
args = [input_data, other_data, output_data]
|
||||
kwargs = {'inplace': False}
|
||||
|
||||
# estimated results
|
||||
compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
|
||||
|
||||
# actual results
|
||||
input_real_tensor = torch.rand(*tensor_shapes[0], device="cuda:0")
|
||||
other_real_tensor = torch.rand(*tensor_shapes[1], device="cuda:0")
|
||||
|
||||
input_real_tensor.requires_grad = True
|
||||
other_real_tensor.requires_grad = True
|
||||
|
||||
# fwd
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_stamp0 = torch.cuda.memory_allocated()
|
||||
output_real_tensor = torch.matmul(input_real_tensor, other_real_tensor)
|
||||
fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
|
||||
fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
|
||||
|
||||
# bwd
|
||||
upstream_grad = torch.rand_like(output_real_tensor)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_stamp0 = torch.cuda.memory_allocated()
|
||||
torch.autograd.backward(output_real_tensor, upstream_grad)
|
||||
bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
|
||||
bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
|
||||
|
||||
compute_cost: TrainCycleItem
|
||||
memory_cost: TrainCycleItem
|
||||
|
||||
print("=====================")
|
||||
print(f"input shapes: {tensor_shapes[0]}, {tensor_shapes[1]}")
|
||||
print(f"output shapes: {output_tensor.shape}")
|
||||
|
||||
# estimated results
|
||||
print("Estimated Results")
|
||||
|
||||
# compute cost
|
||||
print("compute_cost:")
|
||||
print(f" fwd: {compute_cost.fwd}")
|
||||
print(f" bwd: {compute_cost.bwd}")
|
||||
|
||||
# memory cost
|
||||
print("memory_cost:")
|
||||
# fwd
|
||||
print(f" fwd activation: {memory_cost.fwd.activation / 1024} KB")
|
||||
print(f" fwd buffer: {memory_cost.fwd.buffer / 1024} KB")
|
||||
print(f" fwd temp: {memory_cost.fwd.temp / 1024} KB")
|
||||
print(f" fwd parameter: {memory_cost.fwd.parameter / 1024} KB")
|
||||
|
||||
# bwd
|
||||
print(f" bwd activation: {memory_cost.bwd.activation / 1024} KB")
|
||||
print(f" bwd buffer: {memory_cost.bwd.buffer / 1024} KB")
|
||||
print(f" bwd temp: {memory_cost.bwd.temp / 1024} KB")
|
||||
print(f" bwd parameter: {memory_cost.bwd.parameter / 1024} KB")
|
||||
|
||||
# actual results
|
||||
print("Actual Results")
|
||||
|
||||
print("memory_cost:")
|
||||
# fwd
|
||||
print(f" fwd allocated: {fwd_allocated / 1024} KB")
|
||||
print(f" fwd peak: {fwd_peak / 1024} KB")
|
||||
|
||||
# bwd
|
||||
print(f" bwd allocated: {bwd_allocated / 1024} KB")
|
||||
print(f" bwd peak: {bwd_peak / 1024} KB")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_matmul_function_meta_info()
|
|
@ -1,3 +1,4 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -24,6 +25,7 @@ class MatMulModule(nn.Module):
|
|||
return torch.matmul(x1, x2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@parameterize(
|
||||
'tensor_shapes',
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue