mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] Patch meta information of `torch.matmul` (#2584)
* [autoparallel] matmul metainfo * [auto_parallel] remove unused print * [tests] skip test_matmul_handler when torch version is lower than 1.12.0pull/2581/head
parent
4ae02c4b1c
commit
90a9fdd91d
|
@ -1,3 +1,4 @@
|
||||||
|
from functools import reduce
|
||||||
from typing import Callable, Dict, List, Tuple, Union
|
from typing import Callable, Dict, List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -16,7 +17,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
|
||||||
from ..registry import meta_register
|
from ..registry import meta_register
|
||||||
|
|
||||||
__all__ = ['linear_meta_info']
|
__all__ = ['linear_meta_info', 'matmul_meta_info']
|
||||||
|
|
||||||
|
|
||||||
@meta_register.register(torch.nn.functional.linear)
|
@meta_register.register(torch.nn.functional.linear)
|
||||||
|
@ -170,3 +171,235 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
||||||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||||
|
|
||||||
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
|
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
|
||||||
|
|
||||||
|
|
||||||
|
@meta_register.register(torch.matmul)
|
||||||
|
def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||||
|
"""torch.matmul meta info generator
|
||||||
|
There are several cases for torch.matmul:
|
||||||
|
1. Vector-vector multiplication => no temp memory, forward memory cost is 1 element (could be neglected), backward memory cost is the same
|
||||||
|
as two input vectors.
|
||||||
|
2. Matrix-vector multiplication => if the first input is matrix, no temp memory is needed, otherwise, there is a temp memory in the backward
|
||||||
|
phase for the transpose of the matrix. The forward memory cost is the size of output tensor, backward memory cost is the size of the two inputs; if
|
||||||
|
the first input is vector, the forward memory cost is the size of the output tensor, and during the backward phase, it will allocate a temp memory
|
||||||
|
the same size as the input matrix, and allocate memory for the gradient of two inputs.
|
||||||
|
3. Batched Matrix-vector multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of
|
||||||
|
output tensor, backward memory cost is the size of the two inputs; if the second input is the batched matrix, the matmul will allocate memory for
|
||||||
|
the gradient of the batched matrix in the forward phase (as they create a new tensor without the former batches), so the forward memory cost is
|
||||||
|
the output tensor and the newly created matrix (take the same amount of memory of the input batched matrix). During the backward phase, it will
|
||||||
|
allocate a temp memory the same size as input batched matrix, and allocate a tensor for the gradient of the input vector. The gradient of the batched
|
||||||
|
matrix will be stored in the memory allocated during the forward phase.
|
||||||
|
3. Matrix-matrix multiplication => no temp memory, forward memory is the size of output tensor, backward memory is the size of the two inputs
|
||||||
|
4. Batched matrix-matrix multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of two
|
||||||
|
inputs and backward memory cost is the size of the output tensor; if the second input is the batched matrix, during the forward phase it will allocate
|
||||||
|
memory for the output and gradient of the second input, and has a temp memory the same size as the output, during the backward phase, it
|
||||||
|
will allocate memory for the gradient of the first input and has a temp memory which is as big as output and the second input.
|
||||||
|
5. Batched matrix-batched matrix multiplication => if the two inputs have the same batch dimensions, no temp memory, the forward memory cost is the size
|
||||||
|
of output, backward memory cost is the size of the two inputs; it the two inputs have different batch dimensions, during the forward phase it will allocate
|
||||||
|
memory of the expanded inputs (so that the batch dimensions could match) and the output, and during the backward phase, it has a temp memory of the size of
|
||||||
|
two expanded inputs, and it will allocate memory for the gradient of the two inputs and discard the expanded inputs allocated during the forward phase.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Get input and output tensors
|
||||||
|
input_tensors = [args[0].data, args[1].data]
|
||||||
|
output_tensors = [args[-1].data]
|
||||||
|
|
||||||
|
# Check dimension
|
||||||
|
if all(len(tensor.shape) == 1 for tensor in input_tensors):
|
||||||
|
# Dot
|
||||||
|
fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
|
||||||
|
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2
|
||||||
|
|
||||||
|
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
|
||||||
|
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
|
||||||
|
|
||||||
|
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
|
||||||
|
# gemv case 1: matrix-vector multiplication
|
||||||
|
# &
|
||||||
|
# batched gemv case 1: batched matrix-vector multiplication
|
||||||
|
|
||||||
|
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
|
||||||
|
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
|
||||||
|
|
||||||
|
# combine the dimensions of output
|
||||||
|
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
|
||||||
|
[output_tensors[0].reshape(-1), input_tensors[1]],
|
||||||
|
output_tensors) + \
|
||||||
|
flop_mapping[torch.ops.aten.mv.default](
|
||||||
|
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
|
||||||
|
output_tensors)
|
||||||
|
|
||||||
|
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
|
||||||
|
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
|
||||||
|
|
||||||
|
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
|
||||||
|
# gemv case 2: vector-matrix multiplication
|
||||||
|
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
|
||||||
|
|
||||||
|
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
|
||||||
|
flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
|
||||||
|
|
||||||
|
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
|
||||||
|
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
|
||||||
|
parameter=0,
|
||||||
|
temp=activation_size(input_tensors[1]),
|
||||||
|
buffer=0)
|
||||||
|
|
||||||
|
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
|
||||||
|
# batched gemv case 2: vector-batched matrix multiplication
|
||||||
|
|
||||||
|
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
|
||||||
|
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
|
||||||
|
[output_tensors[0].reshape(-1)])
|
||||||
|
|
||||||
|
# combine the dimensions of output
|
||||||
|
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
|
||||||
|
[output_tensors[0].reshape(-1), input_tensors[0]],
|
||||||
|
output_tensors
|
||||||
|
) + \
|
||||||
|
flop_mapping[torch.ops.aten.mv.default](
|
||||||
|
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
|
||||||
|
output_tensors
|
||||||
|
)
|
||||||
|
|
||||||
|
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
|
||||||
|
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
|
||||||
|
parameter=0,
|
||||||
|
temp=activation_size(input_tensors[1]),
|
||||||
|
buffer=0)
|
||||||
|
|
||||||
|
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
|
||||||
|
# gemm & batched gemm case 1: batched matrix-matrix multiplication
|
||||||
|
|
||||||
|
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
|
||||||
|
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
|
||||||
|
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
|
||||||
|
|
||||||
|
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
|
||||||
|
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
|
||||||
|
[input_tensors[1]]
|
||||||
|
) + \
|
||||||
|
flop_mapping[torch.ops.aten.mm.default](
|
||||||
|
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
|
||||||
|
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
|
||||||
|
)
|
||||||
|
|
||||||
|
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
|
||||||
|
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
|
||||||
|
|
||||||
|
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
|
||||||
|
# batched gemm case 2: matrix-batched matrix multiplication
|
||||||
|
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
|
||||||
|
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
|
||||||
|
0, 1)
|
||||||
|
], [output_tensors[0].transpose(-2, -1)])
|
||||||
|
|
||||||
|
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
|
||||||
|
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
|
||||||
|
[input_tensors[0]]
|
||||||
|
) + \
|
||||||
|
flop_mapping[torch.ops.aten.mm.default](
|
||||||
|
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
|
||||||
|
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
|
||||||
|
)
|
||||||
|
|
||||||
|
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
|
||||||
|
temp=activation_size(output_tensors))
|
||||||
|
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
|
||||||
|
parameter=0,
|
||||||
|
temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
|
||||||
|
|
||||||
|
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
|
||||||
|
# Batched matrix-batched matrix multiplication
|
||||||
|
# Fetch shape of the two inputs and see if the batch dimensions are the same
|
||||||
|
_is_batch_dims_same = True
|
||||||
|
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
|
||||||
|
for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
|
||||||
|
if shape_0 != shape_1:
|
||||||
|
_is_batch_dims_same = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
_is_batch_dims_same = False
|
||||||
|
|
||||||
|
# retireve dimensions
|
||||||
|
input_dim_00 = input_tensors[0].shape[-2]
|
||||||
|
input_dim_01 = input_tensors[0].shape[-1]
|
||||||
|
input_dim_10 = input_tensors[1].shape[-2]
|
||||||
|
input_dim_11 = input_tensors[1].shape[-1]
|
||||||
|
output_dim_0 = output_tensors[0].shape[-2]
|
||||||
|
output_dim_1 = output_tensors[0].shape[-1]
|
||||||
|
|
||||||
|
if _is_batch_dims_same:
|
||||||
|
# Case 1: batch dimensions are the same
|
||||||
|
|
||||||
|
# Forward compute cost: C = A * B
|
||||||
|
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
|
||||||
|
input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
|
||||||
|
-1, input_dim_10, input_dim_11)
|
||||||
|
], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
|
||||||
|
|
||||||
|
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
|
||||||
|
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
|
||||||
|
[input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
|
||||||
|
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
|
||||||
|
) + \
|
||||||
|
flop_mapping[torch.ops.aten.bmm.default](
|
||||||
|
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
|
||||||
|
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
|
||||||
|
)
|
||||||
|
|
||||||
|
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
|
||||||
|
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Case 2: batch dimensions are different
|
||||||
|
batch_dims = output_tensors[0].shape[:-2]
|
||||||
|
extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
|
||||||
|
input_dim_00,
|
||||||
|
input_dim_01,
|
||||||
|
device="meta")
|
||||||
|
extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
|
||||||
|
input_dim_10,
|
||||||
|
input_dim_11,
|
||||||
|
device="meta")
|
||||||
|
|
||||||
|
# Forward compute cost: C = A * B
|
||||||
|
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
|
||||||
|
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
|
||||||
|
|
||||||
|
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
|
||||||
|
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
|
||||||
|
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
|
||||||
|
[extended_input_1]
|
||||||
|
) + \
|
||||||
|
flop_mapping[torch.ops.aten.bmm.default](
|
||||||
|
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
|
||||||
|
[extended_input_0]
|
||||||
|
)
|
||||||
|
|
||||||
|
fwd_mem_cost = MemoryCost(
|
||||||
|
activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
|
||||||
|
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
|
||||||
|
activation_size([extended_input_0, extended_input_1]),
|
||||||
|
temp=activation_size([extended_input_0, extended_input_1]))
|
||||||
|
|
||||||
|
# compute cost
|
||||||
|
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||||
|
|
||||||
|
# memory cost
|
||||||
|
total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
|
||||||
|
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
|
||||||
|
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
|
||||||
|
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
|
||||||
|
|
||||||
|
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
|
||||||
|
|
||||||
|
# store fwd_in, fwd_buffer, fwd_out
|
||||||
|
fwd_in = input_tensors
|
||||||
|
fwd_buffer = []
|
||||||
|
fwd_out = output_tensors
|
||||||
|
|
||||||
|
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
|
||||||
|
|
|
@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
|
||||||
|
|
||||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||||
from ..utils import recover_sharding_spec_for_broadcast_shape
|
from ..utils import recover_sharding_spec_for_broadcast_shape
|
||||||
from .node_handler import NodeHandler
|
from .node_handler import MetaInfoNodeHandler, NodeHandler
|
||||||
from .registry import operator_registry
|
from .registry import operator_registry
|
||||||
from .strategy import (
|
from .strategy import (
|
||||||
BatchedMatMulStrategyGenerator,
|
BatchedMatMulStrategyGenerator,
|
||||||
|
@ -326,7 +326,7 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
|
||||||
|
|
||||||
@operator_registry.register(torch.matmul)
|
@operator_registry.register(torch.matmul)
|
||||||
@operator_registry.register(torch.Tensor.matmul)
|
@operator_registry.register(torch.Tensor.matmul)
|
||||||
class MatMulHandler(NodeHandler):
|
class MatMulHandler(MetaInfoNodeHandler):
|
||||||
"""
|
"""
|
||||||
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
|
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
|
||||||
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
|
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
|
||||||
|
|
|
@ -16,6 +16,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||||
)
|
)
|
||||||
from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
|
from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
|
||||||
from .strategy import StrategyGenerator
|
from .strategy import StrategyGenerator
|
||||||
|
@ -266,6 +267,10 @@ class MetaInfoNodeHandler(NodeHandler):
|
||||||
# attach metainfos to the handler
|
# attach metainfos to the handler
|
||||||
setattr(self, "metainfo_vector", metainfo_vector)
|
setattr(self, "metainfo_vector", metainfo_vector)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger = get_dist_logger()
|
||||||
|
logger.warning(f'The target function {target} is not patched yet, ')
|
||||||
|
|
||||||
return self.strategies_vector
|
return self.strategies_vector
|
||||||
|
|
||||||
|
|
||||||
|
@ -317,4 +322,8 @@ class MetaInfoModuleHandler(ModuleHandler):
|
||||||
# attach metainfos to the handler
|
# attach metainfos to the handler
|
||||||
setattr(self, "metainfo_vector", metainfo_vector)
|
setattr(self, "metainfo_vector", metainfo_vector)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger = get_dist_logger()
|
||||||
|
logger.warning(f'The target function {target} is not patched yet')
|
||||||
|
|
||||||
return self.strategies_vector
|
return self.strategies_vector
|
||||||
|
|
|
@ -20,7 +20,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||||
# Inputs contains the shapes of two matrices.
|
# Inputs contains the shapes of two matrices.
|
||||||
input_shapes = [v.shape for v in inputs]
|
input_shapes = [v.shape for v in inputs]
|
||||||
assert len(input_shapes) == 2, input_shapes
|
assert len(input_shapes) == 2, input_shapes
|
||||||
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
|
|
||||||
|
# There are three cases: 1) gemm, 2) gemv, 3) dot
|
||||||
|
if all(len(shape) == 2 for shape in input_shapes):
|
||||||
|
# gemm
|
||||||
|
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
|
||||||
|
elif all(len(shape) == 1 for shape in input_shapes):
|
||||||
|
# dot
|
||||||
|
assert input_shapes[0][0] == input_shapes[1][0], input_shapes
|
||||||
|
|
||||||
|
# expand shape
|
||||||
|
input_shapes[0] = torch.Size([1, input_shapes[0][0]])
|
||||||
|
input_shapes[1] = torch.Size([input_shapes[1][0], 1])
|
||||||
|
else:
|
||||||
|
# gemv
|
||||||
|
if len(input_shapes[0]) == 1:
|
||||||
|
assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
|
||||||
|
input_shapes.reverse()
|
||||||
|
else:
|
||||||
|
assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
|
||||||
|
|
||||||
|
# expand the shape of the vector to [batch size, 1]
|
||||||
|
input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
|
||||||
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
|
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
|
@ -204,8 +225,10 @@ def zero_flop_jit(*args):
|
||||||
|
|
||||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||||
flop_mapping = {
|
flop_mapping = {
|
||||||
# gemm
|
# gemm, gemv and dot
|
||||||
aten.mm.default: matmul_flop_jit,
|
aten.mm.default: matmul_flop_jit,
|
||||||
|
aten.mv.default: matmul_flop_jit,
|
||||||
|
aten.dot.default: matmul_flop_jit,
|
||||||
aten.matmul.default: matmul_flop_jit,
|
aten.matmul.default: matmul_flop_jit,
|
||||||
aten.addmm.default: addmm_flop_jit,
|
aten.addmm.default: addmm_flop_jit,
|
||||||
aten.bmm.default: bmm_flop_jit,
|
aten.bmm.default: bmm_flop_jit,
|
||||||
|
|
|
@ -0,0 +1,145 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||||
|
MemoryCost,
|
||||||
|
OperationData,
|
||||||
|
OperationDataType,
|
||||||
|
ShardingStrategy,
|
||||||
|
StrategiesVector,
|
||||||
|
TrainCycleItem,
|
||||||
|
)
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||||
|
from colossalai.initialize import launch
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||||
|
|
||||||
|
if torch.__version__ >= '1.12.0':
|
||||||
|
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||||
|
@parameterize(
|
||||||
|
'tensor_shapes',
|
||||||
|
[
|
||||||
|
[[128], [128]], # dot product
|
||||||
|
[[64, 128], [128]], # mat-vec
|
||||||
|
[[128], [128, 64]], # vec-mat
|
||||||
|
[[64, 64, 128], [128]], # batched mat-vec
|
||||||
|
[[128], [64, 128, 64]], # vec-batched mat
|
||||||
|
[[64, 128], [128, 192]], # mat-mat
|
||||||
|
[[64, 64, 128], [128, 192]], # batched mat-mat
|
||||||
|
[[64, 128], [64, 128, 192]], # mat-batched mat
|
||||||
|
[[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims)
|
||||||
|
[[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims)
|
||||||
|
])
|
||||||
|
def test_matmul_function_meta_info(tensor_shapes):
|
||||||
|
meta_func = meta_register.get(torch.matmul)
|
||||||
|
|
||||||
|
# construct meta tensors
|
||||||
|
input_tensor = torch.rand(*tensor_shapes[0], device="meta")
|
||||||
|
other_tensor = torch.rand(*tensor_shapes[1], device="meta")
|
||||||
|
output_tensor = torch.matmul(input_tensor, other_tensor)
|
||||||
|
|
||||||
|
# construct operation data
|
||||||
|
input_data = OperationData(
|
||||||
|
name="input",
|
||||||
|
data=input_tensor,
|
||||||
|
type=OperationDataType.ARG,
|
||||||
|
logical_shape=input_tensor.shape,
|
||||||
|
)
|
||||||
|
other_data = OperationData(
|
||||||
|
name="other",
|
||||||
|
data=other_tensor,
|
||||||
|
type=OperationDataType.ARG,
|
||||||
|
logical_shape=other_tensor.shape,
|
||||||
|
)
|
||||||
|
output_data = OperationData(
|
||||||
|
name="output",
|
||||||
|
data=output_tensor,
|
||||||
|
type=OperationDataType.OUTPUT,
|
||||||
|
logical_shape=output_tensor.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
# construct args and kwargs
|
||||||
|
args = [input_data, other_data, output_data]
|
||||||
|
kwargs = {'inplace': False}
|
||||||
|
|
||||||
|
# estimated results
|
||||||
|
compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
|
||||||
|
|
||||||
|
# actual results
|
||||||
|
input_real_tensor = torch.rand(*tensor_shapes[0], device="cuda:0")
|
||||||
|
other_real_tensor = torch.rand(*tensor_shapes[1], device="cuda:0")
|
||||||
|
|
||||||
|
input_real_tensor.requires_grad = True
|
||||||
|
other_real_tensor.requires_grad = True
|
||||||
|
|
||||||
|
# fwd
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
mem_stamp0 = torch.cuda.memory_allocated()
|
||||||
|
output_real_tensor = torch.matmul(input_real_tensor, other_real_tensor)
|
||||||
|
fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
|
||||||
|
fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
|
||||||
|
|
||||||
|
# bwd
|
||||||
|
upstream_grad = torch.rand_like(output_real_tensor)
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
mem_stamp0 = torch.cuda.memory_allocated()
|
||||||
|
torch.autograd.backward(output_real_tensor, upstream_grad)
|
||||||
|
bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
|
||||||
|
bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
|
||||||
|
|
||||||
|
compute_cost: TrainCycleItem
|
||||||
|
memory_cost: TrainCycleItem
|
||||||
|
|
||||||
|
print("=====================")
|
||||||
|
print(f"input shapes: {tensor_shapes[0]}, {tensor_shapes[1]}")
|
||||||
|
print(f"output shapes: {output_tensor.shape}")
|
||||||
|
|
||||||
|
# estimated results
|
||||||
|
print("Estimated Results")
|
||||||
|
|
||||||
|
# compute cost
|
||||||
|
print("compute_cost:")
|
||||||
|
print(f" fwd: {compute_cost.fwd}")
|
||||||
|
print(f" bwd: {compute_cost.bwd}")
|
||||||
|
|
||||||
|
# memory cost
|
||||||
|
print("memory_cost:")
|
||||||
|
# fwd
|
||||||
|
print(f" fwd activation: {memory_cost.fwd.activation / 1024} KB")
|
||||||
|
print(f" fwd buffer: {memory_cost.fwd.buffer / 1024} KB")
|
||||||
|
print(f" fwd temp: {memory_cost.fwd.temp / 1024} KB")
|
||||||
|
print(f" fwd parameter: {memory_cost.fwd.parameter / 1024} KB")
|
||||||
|
|
||||||
|
# bwd
|
||||||
|
print(f" bwd activation: {memory_cost.bwd.activation / 1024} KB")
|
||||||
|
print(f" bwd buffer: {memory_cost.bwd.buffer / 1024} KB")
|
||||||
|
print(f" bwd temp: {memory_cost.bwd.temp / 1024} KB")
|
||||||
|
print(f" bwd parameter: {memory_cost.bwd.parameter / 1024} KB")
|
||||||
|
|
||||||
|
# actual results
|
||||||
|
print("Actual Results")
|
||||||
|
|
||||||
|
print("memory_cost:")
|
||||||
|
# fwd
|
||||||
|
print(f" fwd allocated: {fwd_allocated / 1024} KB")
|
||||||
|
print(f" fwd peak: {fwd_peak / 1024} KB")
|
||||||
|
|
||||||
|
# bwd
|
||||||
|
print(f" bwd allocated: {bwd_allocated / 1024} KB")
|
||||||
|
print(f" bwd peak: {bwd_peak / 1024} KB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_matmul_function_meta_info()
|
|
@ -1,3 +1,4 @@
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
@ -24,6 +25,7 @@ class MatMulModule(nn.Module):
|
||||||
return torch.matmul(x1, x2)
|
return torch.matmul(x1, x2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||||
@parameterize(
|
@parameterize(
|
||||||
'tensor_shapes',
|
'tensor_shapes',
|
||||||
[
|
[
|
||||||
|
|
Loading…
Reference in New Issue