[autoparallel] Patch meta information of `torch.matmul` (#2584)

* [autoparallel] matmul metainfo

* [auto_parallel] remove unused print

* [tests] skip test_matmul_handler when torch version is lower than 1.12.0
Boyuan Yao 2023-02-08 11:05:31 +08:00 committed by GitHub
parent 4ae02c4b1c
commit 90a9fdd91d
6 changed files with 417 additions and 5 deletions


@@ -1,3 +1,4 @@
from functools import reduce
from typing import Callable, Dict, List, Tuple, Union
import torch
@@ -16,7 +17,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['linear_meta_info']
__all__ = ['linear_meta_info', 'matmul_meta_info']
@meta_register.register(torch.nn.functional.linear)
@@ -170,3 +171,235 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@meta_register.register(torch.matmul)
def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.matmul meta info generator
There are several cases for torch.matmul:
1. Vector-vector multiplication => no temp memory; the forward memory cost is 1 element (negligible) and the backward memory cost is the same
as the two input vectors.
2. Matrix-vector multiplication => if the first input is the matrix, no temp memory is needed: the forward memory cost is the size of the
output tensor and the backward memory cost is the size of the two inputs. If the first input is the vector, the forward memory cost is still the
size of the output tensor, but the backward phase allocates a temp memory the same size as the input matrix (for its transpose) in addition to
the memory for the gradients of the two inputs.
3. Batched matrix-vector multiplication => if the first input is the batched matrix, no temp memory: the forward memory cost is the size of the
output tensor and the backward memory cost is the size of the two inputs. If the second input is the batched matrix, matmul allocates memory for
the gradient of the batched matrix already in the forward phase (as a new tensor is created without the former batches), so the forward memory
cost is the output tensor plus this newly created tensor (the same amount of memory as the input batched matrix). The backward phase allocates a
temp memory the same size as the input batched matrix and a tensor for the gradient of the input vector; the gradient of the batched matrix is
stored in the memory allocated during the forward phase.
4. Matrix-matrix multiplication => no temp memory; the forward memory cost is the size of the output tensor and the backward memory cost is the size of the two inputs.
5. Batched matrix-matrix multiplication => if the first input is the batched matrix, no temp memory: the forward memory cost is the size of the
output tensor and the backward memory cost is the size of the two inputs. If the second input is the batched matrix, the forward phase allocates
memory for the output and the gradient of the second input and uses a temp memory the same size as the output; the backward phase allocates
memory for the gradient of the first input and uses a temp memory as large as the output plus the second input.
6. Batched matrix-batched matrix multiplication => if the two inputs have the same batch dimensions, no temp memory: the forward memory cost is the
size of the output and the backward memory cost is the size of the two inputs. If the two inputs have different batch dimensions, the forward phase
allocates memory for the expanded inputs (so that the batch dimensions match) and the output; the backward phase uses a temp memory the size of the
two expanded inputs, allocates memory for the gradients of the two inputs, and discards the expanded inputs allocated during the forward phase.
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
# Get input and output tensors
input_tensors = [args[0].data, args[1].data]
output_tensors = [args[-1].data]
# Check dimension
if all(len(tensor.shape) == 1 for tensor in input_tensors):
# Dot
fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
# gemv case 1: matrix-vector multiplication
# &
# batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
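# note: flattening the leading dimensions of the (possibly batched) matrix to 2D lets the
# aten.mv flop counter cover both the plain and the batched matrix-vector case with a single call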
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[1]],
output_tensors) + \
flop_mapping[torch.ops.aten.mv.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
# gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
parameter=0,
temp=activation_size(input_tensors[1]),
buffer=0)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
[output_tensors[0].reshape(-1)])
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[0]],
output_tensors
) + \
flop_mapping[torch.ops.aten.mv.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors
)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
parameter=0,
temp=activation_size(input_tensors[1]),
buffer=0)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
# gemm & batched gemm case 1: batched matrix-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
[input_tensors[1]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
0, 1)
], [output_tensors[0].transpose(-2, -1)])
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
[input_tensors[0]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
temp=activation_size(output_tensors))
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
parameter=0,
temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication
# Fetch shape of the two inputs and see if the batch dimensions are the same
_is_batch_dims_same = True
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
if shape_0 != shape_1:
_is_batch_dims_same = False
break
else:
_is_batch_dims_same = False
# retrieve dimensions
input_dim_00 = input_tensors[0].shape[-2]
input_dim_01 = input_tensors[0].shape[-1]
input_dim_10 = input_tensors[1].shape[-2]
input_dim_11 = input_tensors[1].shape[-1]
output_dim_0 = output_tensors[0].shape[-2]
output_dim_1 = output_tensors[0].shape[-1]
if _is_batch_dims_same:
# Case 1: batch dimensions are the same
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
-1, input_dim_10, input_dim_11)
], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
else:
# Case 2: batch dimensions are different
batch_dims = output_tensors[0].shape[:-2]
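# illustrative example: for inputs of shape [64, 1, 64, 128] and [64, 128, 192] the broadcast
# batch dims of the output are (64, 64), so both operands are treated as 64 * 64 = 4096 batched
# matrices of shape [64, 128] and [128, 192] for the flop and memory estimates below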
extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_00,
input_dim_01,
device="meta")
extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_10,
input_dim_11,
device="meta")
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[extended_input_1]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
[extended_input_0]
)
fwd_mem_cost = MemoryCost(
activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
activation_size([extended_input_0, extended_input_1]),
temp=activation_size([extended_input_0, extended_input_1]))
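# the negative contribution of the expanded buffers above reflects that they are discarded during
# the backward pass (see case 6 in the docstring), while their size is still counted as backward temp memory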
# compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = input_tensors
fwd_buffer = []
fwd_out = output_tensors
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out


@@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
@@ -326,7 +326,7 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
class MatMulHandler(MetaInfoNodeHandler):
"""
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on


@@ -16,6 +16,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
from colossalai.device.device_mesh import DeviceMesh
from colossalai.logging import get_dist_logger
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from .strategy import StrategyGenerator
@@ -266,6 +267,10 @@ class MetaInfoNodeHandler(NodeHandler):
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
else:
logger = get_dist_logger()
logger.warning(f'The target function {target} is not patched yet')
return self.strategies_vector
@@ -317,4 +322,8 @@ class MetaInfoModuleHandler(ModuleHandler):
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
else:
logger = get_dist_logger()
logger.warning(f'The target function {target} is not patched yet')
return self.strategies_vector


@@ -20,7 +20,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# Inputs contains the shapes of two matrices.
input_shapes = [v.shape for v in inputs]
assert len(input_shapes) == 2, input_shapes
# There are three cases: 1) gemm, 2) gemv, 3) dot
if all(len(shape) == 2 for shape in input_shapes):
# gemm
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
elif all(len(shape) == 1 for shape in input_shapes):
# dot
assert input_shapes[0][0] == input_shapes[1][0], input_shapes
# expand shape
input_shapes[0] = torch.Size([1, input_shapes[0][0]])
input_shapes[1] = torch.Size([input_shapes[1][0], 1])
else:
# gemv
if len(input_shapes[0]) == 1:
assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
input_shapes.reverse()
else:
assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
# expand the shape of the vector to [vector length, 1]
input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
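# worked example (illustrative): for shapes [64, 128] x [128] the vector is reshaped to [128, 1],
# so this counter reports 64 * 128 * 1 = 8192 multiply-accumulate operations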
return flops
@@ -204,8 +225,10 @@ def zero_flop_jit(*args):
if version.parse(torch.__version__) >= version.parse('1.12.0'):
flop_mapping = {
# gemm
# gemm, gemv and dot
aten.mm.default: matmul_flop_jit,
aten.mv.default: matmul_flop_jit,
aten.dot.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,


@@ -0,0 +1,145 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@parameterize(
'tensor_shapes',
[
[[128], [128]], # dot product
[[64, 128], [128]], # mat-vec
[[128], [128, 64]], # vec-mat
[[64, 64, 128], [128]], # batched mat-vec
[[128], [64, 128, 64]], # vec-batched mat
[[64, 128], [128, 192]], # mat-mat
[[64, 64, 128], [128, 192]], # batched mat-mat
[[64, 128], [64, 128, 192]], # mat-batched mat
[[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims)
[[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims)
])
def test_matmul_function_meta_info(tensor_shapes):
meta_func = meta_register.get(torch.matmul)
# construct meta tensors
input_tensor = torch.rand(*tensor_shapes[0], device="meta")
other_tensor = torch.rand(*tensor_shapes[1], device="meta")
output_tensor = torch.matmul(input_tensor, other_tensor)
# construct operation data
input_data = OperationData(
name="input",
data=input_tensor,
type=OperationDataType.ARG,
logical_shape=input_tensor.shape,
)
other_data = OperationData(
name="other",
data=other_tensor,
type=OperationDataType.ARG,
logical_shape=other_tensor.shape,
)
output_data = OperationData(
name="output",
data=output_tensor,
type=OperationDataType.OUTPUT,
logical_shape=output_tensor.shape,
)
# construct args and kwargs
args = [input_data, other_data, output_data]
kwargs = {'inplace': False}
# estimated results
compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
# actual results
input_real_tensor = torch.rand(*tensor_shapes[0], device="cuda:0")
other_real_tensor = torch.rand(*tensor_shapes[1], device="cuda:0")
input_real_tensor.requires_grad = True
other_real_tensor.requires_grad = True
# fwd
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
output_real_tensor = torch.matmul(input_real_tensor, other_real_tensor)
fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
# bwd
upstream_grad = torch.rand_like(output_real_tensor)
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
torch.autograd.backward(output_real_tensor, upstream_grad)
bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
compute_cost: TrainCycleItem
memory_cost: TrainCycleItem
print("=====================")
print(f"input shapes: {tensor_shapes[0]}, {tensor_shapes[1]}")
print(f"output shapes: {output_tensor.shape}")
# estimated results
print("Estimated Results")
# compute cost
print("compute_cost:")
print(f" fwd: {compute_cost.fwd}")
print(f" bwd: {compute_cost.bwd}")
# memory cost
print("memory_cost:")
# fwd
print(f" fwd activation: {memory_cost.fwd.activation / 1024} KB")
print(f" fwd buffer: {memory_cost.fwd.buffer / 1024} KB")
print(f" fwd temp: {memory_cost.fwd.temp / 1024} KB")
print(f" fwd parameter: {memory_cost.fwd.parameter / 1024} KB")
# bwd
print(f" bwd activation: {memory_cost.bwd.activation / 1024} KB")
print(f" bwd buffer: {memory_cost.bwd.buffer / 1024} KB")
print(f" bwd temp: {memory_cost.bwd.temp / 1024} KB")
print(f" bwd parameter: {memory_cost.bwd.parameter / 1024} KB")
# actual results
print("Actual Results")
print("memory_cost:")
# fwd
print(f" fwd allocated: {fwd_allocated / 1024} KB")
print(f" fwd peak: {fwd_peak / 1024} KB")
# bwd
print(f" bwd allocated: {bwd_allocated / 1024} KB")
print(f" bwd peak: {bwd_peak / 1024} KB")
if __name__ == '__main__':
test_matmul_function_meta_info()


@@ -1,3 +1,4 @@
import pytest
import torch
import torch.nn as nn
@@ -24,6 +25,7 @@ class MatMulModule(nn.Module):
return torch.matmul(x1, x2)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@parameterize(
'tensor_shapes',
[