mirror of https://github.com/hpcaitech/ColossalAI

[autoparallel] added matmul handler (#1763)

* [autoparallel] added matmul handler
* polish code

pull/1784/head
parent 4df0194976
commit f3f19a5c47
@@ -4,6 +4,7 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
+from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
 from .output_handler import OuputHandler
 from .placeholder_handler import PlacehodlerHandler
@@ -16,5 +17,5 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'BinaryElementwiseHandler', 'operator_registry'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry'
 ]
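With the import and `__all__` entry above, the new handler should now be importable from the node handler package. A minimal import sketch (the package path is taken from the imports used in the new test file added by this commit):

    from colossalai.auto_parallel.tensor_shard.node_handler import MatMulHandler, operator_registry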
@@ -0,0 +1,482 @@
import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import Dict, List, Union

import torch

from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
    BroadcastType,
    get_broadcast_dim_info,
    get_broadcast_shape,
)
from colossalai.tensor.sharding_spec import ShardingSpecException

from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (
    BatchedMatMulStrategyGenerator,
    DotProductStrategyGenerator,
    LinearProjectionStrategyGenerator,
    MatVecStrategyGenerator,
    StrategyGenerator,
)


class MatMulType(Enum):
    """
    The MatMulType is categorized into 4 types based on the torch.matmul reference
    at https://pytorch.org/docs/stable/generated/torch.matmul.html.

    DOT: dot product, both tensors are 1D and need to have the same number of elements
    MM: matrix-matrix product, both tensors are 2D, or the 1st tensor is 1D and the 2nd tensor is 2D
    MV: matrix-vector product, the 1st tensor is 2D and the 2nd tensor is 1D
    BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
    """
    DOT = 0
    MM = 1
    MV = 2
    BMM = 3


def get_matmul_type(input_dim: int, other_dim: int):
    """
    Determine which type of matmul operation should be executed for the given tensor dimensions.

    Args:
        input_dim (int): the number of dimensions of the input tensor
        other_dim (int): the number of dimensions of the other tensor
    """
    if input_dim == 1 and other_dim == 1:
        matmul_type = MatMulType.DOT
    elif input_dim in [1, 2] and other_dim == 2:
        matmul_type = MatMulType.MM
    elif input_dim == 2 and other_dim == 1:
        matmul_type = MatMulType.MV
    elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
        matmul_type = MatMulType.BMM
    else:
        raise ValueError(
            f"The input and other tensors have {input_dim} and {other_dim} dimensions respectively, "
            "which cannot be used to execute a matmul operation")
    return matmul_type

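# Editor's note: an illustrative sketch, not part of this commit. The dispatch above mirrors
# the torch.matmul dimension rules:
assert get_matmul_type(1, 1) == MatMulType.DOT    # vector . vector
assert get_matmul_type(1, 2) == MatMulType.MM     # the 1D input is treated as a 1 x n matrix
assert get_matmul_type(2, 1) == MatMulType.MV     # matrix-vector product
assert get_matmul_type(3, 2) == MatMulType.BMM    # batched matmul with broadcasting
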
class BmmTransform(ABC):
    """
    BmmTransform is an abstraction of the shape conversion between logical and physical operation data
    during the strategy generation.
    """

    @abstractmethod
    def apply(self, shape_mapping: Dict[str, List[int]]):
        pass

    @abstractmethod
    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        pass


class Padder(BmmTransform):
    """
    Add padding to the matrix dimensions for batched matrix multiplication.
    """

    def __init__(self) -> None:
        # keep the padded dim, op_name -> padded_dim
        self.padded_dim_mapping = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = deepcopy(shape_mapping)
        input_shape = mapping_copy['input']
        other_shape = mapping_copy['other']

        if len(input_shape) == 1:
            # if the input is a 1D tensor, 1 is prepended to its shape
            # and it will be removed afterwards
            input_shape.insert(0, 1)
            self.padded_dim_mapping['input'] = -2
            self.padded_dim_mapping['output'] = -2
        elif len(other_shape) == 1:
            # if the other is a 1D tensor, 1 is appended to its shape
            # and it will be removed afterwards
            other_shape.append(1)
            self.padded_dim_mapping['other'] = -1
            self.padded_dim_mapping['output'] = -1
        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        input_op_data = op_data_mapping['input']
        other_op_data = op_data_mapping['other']

        def _remove_padded_dim(key, strategy):
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            tensor_shape = list(sharding_spec.entire_shape)
            dim_partition_list = [None] * len(tensor_shape)

            # the padded dim is a negative number as the padded dim must be a matrix dim
            padded_dim = self.padded_dim_mapping[key]

            # compute the new dim partition
            for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
                dim_partition_list[tensor_dim] = mesh_dims
            dim_partition_list.pop(padded_dim)
            unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}

            # compute the unpadded tensor shape
            tensor_shape.pop(padded_dim)

            assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'

            # update the sharding spec
            sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()

            # only one of input and other will be padded
            if 'input' in self.padded_dim_mapping:
                _remove_padded_dim('input', strategy_copy)
                _remove_padded_dim('output', strategy_copy)
            elif 'other' in self.padded_dim_mapping:
                _remove_padded_dim('other', strategy_copy)
                _remove_padded_dim('output', strategy_copy)

            strategies.append(strategy_copy)
        except ShardingSpecException:
            pass
        return strategies

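# Editor's note: an illustrative sketch, not part of this commit. For a 1D "other" operand the
# Padder appends a trailing 1 and remembers which matrix dim must be stripped again later:
padder = Padder()
padded = padder.apply({'input': [4, 8, 16], 'other': [16]})
assert padded == {'input': [4, 8, 16], 'other': [16, 1]}
assert padder.padded_dim_mapping == {'other': -1, 'output': -1}
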
class Broadcaster(BmmTransform):
    """
    Broadcast the non-matrix dimensions for batched matrix multiplication.
    """

    def __init__(self) -> None:
        self.broadcast_dim_info = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = shape_mapping.copy()

        # get shapes
        input_shape = mapping_copy['input']
        other_shape = mapping_copy['other']

        # sanity check
        assert len(input_shape) > 1 and len(other_shape) > 1

        # broadcast the batch dims and record them
        bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2])

        # store the broadcast dim info
        input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
        other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
        self.broadcast_dim_info['input'] = input_broadcast_dim_info
        self.broadcast_dim_info['other'] = other_broadcast_dim_info

        # create the full logical shape
        input_shape = bcast_non_matrix_dims + input_shape[-2:]
        other_shape = bcast_non_matrix_dims + other_shape[-2:]
        assert len(input_shape) == len(other_shape)

        mapping_copy['input'] = input_shape
        mapping_copy['other'] = other_shape

        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        # remove sharding on the broadcast dims
        def _remove_sharding_on_broadcast_dim(key, strategy):
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            tensor_shape = list(sharding_spec.entire_shape)

            for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
                if broadcast_type == BroadcastType.MULTIPLE:
                    # if the dim is originally 1 and multiplied during broadcast
                    # we set its sharding to R
                    # e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
                    # the dim 0 of [1, 2, 4] is multiplied to 4
                    tensor_shape[dim_idx] = 1
                elif broadcast_type == BroadcastType.PADDDING:
                    # if the dim is padded
                    # we remove its sharding
                    tensor_shape[dim_idx] = None

            tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]

            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
                logical_sharding_spec=sharding_spec,
                logical_shape=sharding_spec.entire_shape,
                physical_shape=tensor_shape_before_broadcast)
            strategy.sharding_specs[op_data] = physical_sharding_spec

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()
            _remove_sharding_on_broadcast_dim('input', strategy_copy)
            _remove_sharding_on_broadcast_dim('other', strategy_copy)
            strategies.append(strategy_copy)
        except ShardingSpecException:
            pass
        return strategies

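# Editor's note: an illustrative sketch, not part of this commit. Only the non-matrix (batch)
# dims are broadcast; the last two matrix dims are left untouched. This assumes
# get_broadcast_shape implements standard right-aligned broadcasting, which is how it is
# exercised by the test cases added later in this commit:
broadcaster = Broadcaster()
bcast = broadcaster.apply({'input': [1, 8, 16], 'other': [2, 4, 16, 32]})
assert bcast == {'input': [2, 4, 8, 16], 'other': [2, 4, 16, 32]}
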
class Viewer(BmmTransform):
    """
    Change the shape of the tensor from N-D to 3D.
    """

    def __init__(self) -> None:
        self.batch_dims_before_view = None

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = shape_mapping.copy()
        self.batch_dims_before_view = list(mapping_copy['input'][:-2])

        # get shapes
        input_shape = shape_mapping['input']
        other_shape = shape_mapping['other']

        # view to 3D tensors
        assert len(input_shape) >= 3 and len(other_shape) >= 3
        input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
        other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
        output_shape = input_shape[:2] + other_shape[2:]
        mapping_copy['input'] = input_shape
        mapping_copy['other'] = other_shape
        mapping_copy['output'] = output_shape
        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        # get operation data
        def _update_sharding_spec(key, strategy, physical_batch_dim):
            """
            Map the logical batch dim to the physical batch dim.
            """
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            dim_partition_dict = sharding_spec.dim_partition_dict
            entire_shape = sharding_spec.entire_shape

            # update the dimension index for the matrix dimensions
            if 2 in dim_partition_dict:
                dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
            if 1 in dim_partition_dict:
                dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)

            # map the logical batch dim to the physical batch dim
            if 0 in dim_partition_dict:
                batch_dim_shard = dim_partition_dict.pop(0)
                dim_partition_dict[physical_batch_dim] = batch_dim_shard

            # the new shape will be the batch dims + the last 2 matrix dims
            shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
            sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)

        num_batch_dim_before_view = len(self.batch_dims_before_view)

        # enumerate all sharding strategies
        strategies = []
        for i in range(num_batch_dim_before_view):
            # create a new strategy
            strategy_copy = strategy.clone()
            try:
                _update_sharding_spec('input', strategy_copy, i)
                _update_sharding_spec('other', strategy_copy, i)
                _update_sharding_spec('output', strategy_copy, i)
                strategies.append(strategy_copy)
            except ShardingSpecException:
                continue
        return strategies

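# Editor's note: an illustrative sketch, not part of this commit. The Viewer collapses all
# batch dims into a single leading dim so every BMM case looks like a plain 3D bmm:
viewer = Viewer()
viewed = viewer.apply({'input': [2, 4, 8, 16], 'other': [2, 4, 16, 32]})
assert viewed == {'input': [8, 8, 16], 'other': [8, 16, 32], 'output': [8, 8, 32]}
assert viewer.batch_dims_before_view == [2, 4]
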
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
    """
    Compute the logical shapes for the BMM operation. BMM has the general form

        [b, i, k] = [b, i, j] x [b, j, k]

    The dimension b is called the non-matrix (batch) dimension and the remaining dimensions
    are called matrix dimensions. The logical shapes of the bmm operands undergo three stages:
        1. append/prepend a 1 to the shape of any 1D tensor
        2. broadcast the non-matrix dimensions
        3. reshape to 3 dimensions
    """
    shape_mapping = {'input': input_shape, 'other': other_shape}

    for transform in transforms:
        shape_mapping = transform.apply(shape_mapping)

    input_shape = shape_mapping.get('input', None)
    other_shape = shape_mapping.get('other', None)
    output_shape = shape_mapping.get('output', None)

    return input_shape, other_shape, output_shape

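# Editor's note: an illustrative sketch, not part of this commit. End to end, the three
# transforms turn a padded + broadcast case such as torch.matmul(x1, x2) with x1: [8] and
# x2: [4, 8, 16] into a plain 3D bmm view:
#   Padder:      input [8]    -> [1, 8]     (other unchanged)
#   Broadcaster: input [1, 8] -> [4, 1, 8]  (batch dim broadcast against [4])
#   Viewer:      shapes are already 3D; the output shape becomes [4, 1, 16]
transforms = [Padder(), Broadcaster(), Viewer()]
logical_input, logical_other, logical_output = _get_bmm_logical_shape([8], [4, 8, 16], transforms)
assert (logical_input, logical_other, logical_output) == ([4, 1, 8], [4, 8, 16], [4, 1, 16])
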
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
    """
    The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
    According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operation performed
    varies with the operands.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # check which type of operation this matmul will call
        self.input_meta_data = self.node.args[0]._meta_data
        self.other_meta_data = self.node.args[1]._meta_data
        self.output_meta_data = self.node._meta_data

        input_dim = self.input_meta_data.dim()
        other_dim = self.other_meta_data.dim()
        self.matmul_type = get_matmul_type(input_dim, other_dim)

        if self.matmul_type == MatMulType.BMM:
            # the bmm operation can possibly involve padding, broadcasting and view
            # these transforms will be used to create the logical shape and
            # recover the physical sharding spec
            self.transforms = [Padder(), Broadcaster(), Viewer()]
        else:
            self.transforms = None

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        generators = []
        op_data_mapping = self.get_operation_data_mapping()
        if self.matmul_type == MatMulType.BMM:
            generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.DOT:
            generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.MV:
            generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.MM:
            generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        logical_shape_func = {
            MatMulType.DOT: self._get_logical_shape_for_dot,
            MatMulType.MM: self._get_logical_shape_for_mm,
            MatMulType.MV: self._get_logical_shape_for_mv,
            MatMulType.BMM: self._get_logical_shape_for_bmm
        }
        logical_shapes = logical_shape_func[self.matmul_type]()
        op_data_mapping = self._get_op_data_mapping(*logical_shapes)
        return op_data_mapping

    def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
        # convert list to torch.Size
        if input_logical_shape:
            input_logical_shape = torch.Size(input_logical_shape)

        if other_logical_shape:
            other_logical_shape = torch.Size(other_logical_shape)

        if output_logical_shape:
            output_logical_shape = torch.Size(output_logical_shape)

        # create op data
        input_op_data = OperationData(name=str(self.node.args[0]),
                                      type=OperationDataType.ARG,
                                      data=self.input_meta_data,
                                      logical_shape=input_logical_shape)
        other_op_data = OperationData(name=str(self.node.args[1]),
                                      type=OperationDataType.ARG,
                                      data=self.other_meta_data,
                                      logical_shape=other_logical_shape)
        output_op_data = OperationData(name=str(self.node),
                                       type=OperationDataType.OUTPUT,
                                       data=self.output_meta_data,
                                       logical_shape=output_logical_shape)

        mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
        return mapping

    def _get_logical_shape_for_dot(self):
        """
        The operands of the dot operation have the same logical shape as their physical shape.
        """
        return None, None, None

    def _get_logical_shape_for_mm(self):
        """
        We need to handle the input tensor for a matrix-matrix multiplication as the input
        tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
        (e.g. [4] -> [1, 4]).
        """
        if self.input_meta_data.dim() == 1:
            input_logical_shape = [1] + list(self.input_meta_data.shape)
            input_logical_shape = torch.Size(input_logical_shape)
        else:
            input_logical_shape = None
        return input_logical_shape, None, None

    def _get_logical_shape_for_mv(self):
        """
        No broadcasting or dim insertion occurs for the matrix-vector operation.
        """
        return None, None, None

    def _get_logical_shape_for_bmm(self):
        input_physical_shape = list(self.input_meta_data.shape)
        other_physical_shape = list(self.other_meta_data.shape)
        return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
            return strategy
        elif self.matmul_type == MatMulType.MM:
            if self.input_meta_data.dim() == 1:
                # if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
                # we need to remove that dim
                input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
                input_physical_shape = self.node.args[0]._meta_data.shape
                dim_partition_dict = input_sharding_spec.dim_partition_dict

                # remove the partitioning in dim 0
                if 0 in dim_partition_dict:
                    dim_partition_dict.pop(0, None)

                # move the partitioning in dim 1 to dim 0
                if -1 in dim_partition_dict:
                    shard = dim_partition_dict.pop(-1)
                    dim_partition_dict[0] = shard

                # re-init the sharding spec
                input_sharding_spec.__init__(input_sharding_spec.device_mesh,
                                             entire_shape=input_physical_shape,
                                             dim_partition_dict=dim_partition_dict)
                return strategy
            else:
                return strategy
        elif self.matmul_type == MatMulType.BMM:
            op_data_mapping = self.get_operation_data_mapping()

            strategies = [strategy]
            # recover the physical sharding spec
            for transform in self.transforms[::-1]:
                recovered_strategies = []
                for strategy_ in strategies:
                    output = transform.recover(op_data_mapping, strategy_)
                    if isinstance(output, ShardingStrategy):
                        recovered_strategies.append(output)
                    elif isinstance(output, (list, tuple)):
                        recovered_strategies.extend(output)
                    else:
                        raise TypeError(
                            f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
                strategies = recovered_strategies
            return strategies
@@ -60,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
         fwd_compute_cost = sharded_input_shape[0]
-        bwd_compute_cost = sharded_input_shape * 2
+        bwd_compute_cost = fwd_compute_cost * 2
         compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                       bwd=bwd_compute_cost,
                                       total=fwd_compute_cost + bwd_compute_cost)
         return compute_cost

+    @ignore_sharding_exception
     def no_split(self):
         name = f'R = R dot R'
         dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
@@ -75,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_one_dim(self, mesh_dim):
         name = f'R = S{mesh_dim} dot S{mesh_dim}'

@@ -93,7 +95,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

-    def generate(self) -> List[ShardingStrategy]:
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []

         # do not split dimensions for dot product
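The `bwd_compute_cost` fix above matters because `get_sharded_shape_per_device()` returns a shape (a sequence), and multiplying a sequence by 2 repeats it instead of doubling a number. A minimal sketch of the difference, using a hypothetical per-device shape:

    sharded_input_shape = (128,)                    # hypothetical per-device shape
    fwd_compute_cost = sharded_input_shape[0]
    assert sharded_input_shape * 2 == (128, 128)    # old expression: a tuple, not a cost
    assert fwd_compute_cost * 2 == 256              # new expression: the intended numeric cost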
@@ -113,24 +115,50 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
-        assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
+        assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1

+    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        fwd_compute_cost = sharded_input_shape[0]
+        bwd_compute_cost = fwd_compute_cost * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
+        return compute_cost
+
+    @ignore_sharding_exception
     def no_split(self):
         name = "R = R x R"
-        dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
+        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
+
+        if self.has_bias:
+            dim_partition_dict['bias'] = {}
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping={})

+    @ignore_sharding_exception
     def split_input_batch(self, mesh_dim):
         name = f'S{mesh_dim}R = S{mesh_dim}R x R'

         # get sharding spec
-        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim]
+            },
+            "other": {},
+            "output": {
+                0: [mesh_dim]
+            },
+        }
+
+        if self.has_bias:
+            dim_partition_dict['bias'] = {}
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication action
+        communication_action_mapping = {}
         if self.is_param('other'):
             other_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['other'],
@@ -144,6 +172,8 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
                 logical_process_axis=mesh_dim,
                 comm_type=CommType.BEFORE,
                 arg_index=1)
+            communication_action_mapping['other'] = other_comm_action
+
         if self.has_bias:
             if self.is_param('bias'):
                 bias_comm_action = self.get_communication_action(
@@ -158,13 +188,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
                 logical_process_axis=mesh_dim,
                 comm_type=CommType.BEFORE,
                 arg_index=2)
-        communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
+            communication_action_mapping['bias'] = bias_comm_action

         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

-    def generate(self) -> List[ShardingStrategy]:
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []

         # no split
@@ -638,7 +668,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
-        assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3
+        assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3

         if 'bias' in self.op_data:
             bias_op_data = self.op_data['bias']
@@ -816,11 +846,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         dim_partition_dict = {
             "input": {
                 0: [mesh_dim_0],
-                -1: [mesh_dim_1]
+                2: [mesh_dim_1]
             },
             "other": {
                 0: [mesh_dim_0],
-                -2: [mesh_dim_1]
+                1: [mesh_dim_1]
             },
             "bias": {},
             "output": {
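In the last hunk above, the sharded matrix dims switch from negative indices (-1, -2) to their positive equivalents (2, 1), presumably to keep the `dim_partition_dict` keys in canonical positive form. For the 3D logical shape used by the batched generator these address the same axes; a small sketch of the equivalence for a hypothetical rank-3 shape:

    shape = (8, 16, 32)              # hypothetical 3D logical shape [b, j, k]
    assert shape[-1] == shape[2]     # -1 and 2 address the same dim
    assert shape[-2] == shape[1]     # -2 and 1 address the same dim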
@@ -186,9 +186,14 @@ class StrategyGenerator(ABC):
         """
         op_data = self.op_data[key]
         sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
+
+        if len(sharded_shape) == 0:
+            num_elements = 1
+        else:
+            num_elements = reduce(operator.mul, sharded_shape)
         dtype = self.op_data[key].data.dtype
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
+        return num_elements * size_per_elem_bytes

     def generate(self) -> List[ShardingStrategy]:
         """
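The guard added above handles 0-dim (scalar) sharded shapes such as the output of a dot product: `reduce` over an empty shape raises a TypeError, so the element count falls back to 1. A minimal sketch of both branches (the helper name here is the editor's own, used only for illustration):

    import operator
    from functools import reduce

    import torch

    def sharded_size_in_bytes(sharded_shape, dtype):
        # mirrors the updated logic: a scalar has exactly one element
        num_elements = 1 if len(sharded_shape) == 0 else reduce(operator.mul, sharded_shape)
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        return num_elements * size_per_elem_bytes

    assert sharded_size_in_bytes(torch.Size([]), torch.float32) == 4      # scalar, e.g. a dot product output
    assert sharded_size_in_bytes(torch.Size([4, 8]), torch.float16) == 64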
@@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
     return dims[::-1]


-def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
-                                              physical_shape: torch.Size) -> ShardingSpec:
-    """
-    This function computes the sharding spec for the physical shape of a broadcast tensor.
-
-    Args:
-        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
-        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
-        physical_shape (torch.Size): the shape of the tensor before broadcasting
-    """
-    # if the two shapes are the same, no broadcast occurs
-    # we directly return the current sharding spec
-    if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
-
+def get_broadcast_dim_info(logical_shape, physical_shape):
     # get the number of dimensions
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
@@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spe
         else:
             logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING

+    return logical_dim_broadcast_info
+
+
+def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
+                                              physical_shape: torch.Size) -> ShardingSpec:
+    """
+    This function computes the sharding spec for the physical shape of a broadcast tensor.
+
+    Args:
+        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
+        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
+        physical_shape (torch.Size): the shape of the tensor before broadcasting
+    """
+    # if the two shapes are the same, no broadcast occurs
+    # we directly return the current sharding spec
+    if list(logical_shape) == list(physical_shape):
+        return logical_sharding_spec
+
+    # get the number of dimensions
+    logical_num_dims = len(logical_shape)
+    physical_num_dims = len(physical_shape)
+
+    # get the broadcast info
+    logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
+
     # generate the sharding spec for the physical shape
     physical_dim_partition = {}
     logical_dim_partition = logical_sharding_spec.dim_partition_dict
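The extracted `get_broadcast_dim_info` marks logical dims that were broadcast from 1 (`MULTIPLE`) or newly padded (`PADDDING`); the underlying rule is ordinary right-aligned broadcasting. A quick cross-check with plain PyTorch (`torch.broadcast_shapes` is used only for illustration here, it is not what the helper calls):

    import torch

    # right-aligned broadcasting of batch dims, e.g. [1, 8, 16] vs [2, 4, 16, 32] -> batch dims [2, 4]
    assert torch.broadcast_shapes((1,), (2, 4)) == torch.Size([2, 4])
    assert torch.broadcast_shapes((1, 2, 4), (4, 1, 4)) == torch.Size([4, 2, 4])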
@@ -1,6 +1,5 @@
 import operator
 from copy import deepcopy
-from enum import Enum
 from functools import reduce

 import torch
@@ -175,6 +174,9 @@ class ShardingSpec:
                  dim_partition_dict=None,
                  sharding_sequence=None):
         self.device_mesh = device_mesh
+
+        if isinstance(entire_shape, (list, tuple)):
+            entire_shape = torch.Size(entire_shape)
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
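The `isinstance` check above lets callers (such as the BMM transforms in this commit, which re-init specs with plain Python lists) pass `entire_shape` as a list or tuple; normalizing to `torch.Size` keeps downstream shape comparisons and slicing consistent:

    import torch

    assert torch.Size([4, 1, 8]) == torch.Size((4, 1, 8))
    assert isinstance(torch.Size([4, 1, 8])[-2:], torch.Size)    # slicing still yields a torch.Size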
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
    MatMulHandler,
    MatMulType,
    _get_bmm_logical_shape,
    get_matmul_type,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize


class MatMulModule(nn.Module):

    def forward(self, x1, x2):
        return torch.matmul(x1, x2)


@parameterize(
    'tensor_shapes',
    [
        [[8], [8]],    # dot product
        [[4, 8], [8]],    # mat-vec product
        [[4, 8], [8, 16]],    # mat-mat product
        [[8], [8, 16]],    # mat-mat product
        [[8], [4, 8, 16]],    # batched mat-mat product with padding + broadcasting
        [[4, 8, 16], [16]],    # batched mat-mat product with padding + broadcasting
        [[4, 8, 16], [16, 32]],    # batched mat-mat product with broadcasting
        [[4, 8, 16], [1, 16, 32]],    # batched mat-mat product with broadcasting
        [[8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[4, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[1, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[1, 4, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[2, 1, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[2, 4, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product without broadcasting
    ])
def test_matmul_node_handler(tensor_shapes):
    input_shape, other_shape = tensor_shapes

    # get output shape
    x1 = torch.rand(*input_shape)
    x2 = torch.rand(*other_shape)
    output_shape = list(torch.matmul(x1, x2).shape)

    # get matmul type
    matmul_type = get_matmul_type(x1.dim(), x2.dim())

    model = MatMulModule()

    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    mod_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(mod_node)

    # build handler
    handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    logical_input_shape = input_shape
    logical_other_shape = other_shape
    logical_output_shape = output_shape
    if matmul_type == MatMulType.MM and len(input_shape) == 1:
        logical_input_shape = [1] + input_shape
    elif matmul_type == MatMulType.BMM:
        logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape(
            input_shape, other_shape, handler.transforms)
    else:
        logical_input_shape = input_shape

    # check input operation data
    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size(input_shape)
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size(logical_input_shape)

    # check other operation data
    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size(other_shape)
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size(logical_other_shape)

    # check output
    assert mapping['output'].name == "matmul"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size(output_shape)
    assert mapping['output'].type == OperationDataType.OUTPUT
    assert mapping['output'].logical_shape == torch.Size(logical_output_shape)

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]

    # ensure there is no duplicate strategy
    if matmul_type != MatMulType.BMM:
        assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list

    for strategy in strategies_vector:
        strategy: ShardingStrategy
        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
        output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')

        if matmul_type == MatMulType.DOT:
            # the dot product produces a scalar
            # results should fulfill:
            # 1. the input and other operands have the same sharding spec
            # 2. the output has no sharding
            assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
            assert len(output_sharding_spec.sharding_sequence) == 0
        elif matmul_type == MatMulType.MV:
            # the matrix-vector product should fulfill:
            # 1. the last dim of the input and other operands should have the same sharding
            # 2. the first dim of the input and output should have the same sharding
            # 3. the output should have only 1 dim
            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
            assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
            assert len(output_sharding_spec.sharding_sequence) == 1
        elif matmul_type == MatMulType.MM:
            # matrix-matrix multiplication should fulfill:
            # 1. if the input is a 2D tensor, the 1st dim of input and output should have the same sharding
            # 2. the input's last dim and the first dim of the other should have the same sharding
            # 3. the last dim of the output and other should have the same sharding
            # 4. the input and output should have the same number of dims
            if len(input_shape) == 2:
                assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0]
            assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
            assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence)
        elif matmul_type == MatMulType.BMM:
            # bmm should fulfill:
            # 1. if the other tensor is not a 1D tensor, the last dim of other and output have the same sharding
            # 2. if the input has more than 1 dim, the second-last dim of input and output have the same sharding
            # 3. if the other has more than 2 dims, the second-last dim of other and the last dim of input should have the same sharding
            if len(other_shape) > 1:
                assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
            if len(input_shape) > 1:
                assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
            if len(other_shape) > 2:
                assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]


if __name__ == '__main__':
    test_matmul_node_handler()
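As a sanity check on the shape cases parameterized above, torch.matmul itself already encodes the padding and broadcasting rules that the logical-shape pipeline mirrors; a minimal sketch using only PyTorch:

    import torch

    assert list(torch.matmul(torch.rand(8), torch.rand(8)).shape) == []              # DOT -> scalar
    assert list(torch.matmul(torch.rand(4, 8), torch.rand(8)).shape) == [4]          # MV
    assert list(torch.matmul(torch.rand(8), torch.rand(4, 8, 16)).shape) == [4, 16]  # BMM with padding + broadcast
    assert list(torch.matmul(torch.rand(2, 1, 8, 16), torch.rand(2, 4, 16, 32)).shape) == [2, 4, 8, 32]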