mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added matmul handler (#1763)
* [autoparallel] added matmul handler * polish codepull/1784/head
parent
4df0194976
commit
f3f19a5c47
|
@ -4,6 +4,7 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
|
|||
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
|
||||
from .layer_norm_handler import LayerNormModuleHandler
|
||||
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
|
||||
from .matmul_handler import MatMulHandler
|
||||
from .normal_pooling_handler import NormPoolingHandler
|
||||
from .output_handler import OuputHandler
|
||||
from .placeholder_handler import PlacehodlerHandler
|
||||
|
@ -16,5 +17,5 @@ __all__ = [
|
|||
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
|
||||
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
|
||||
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
|
||||
'NormPoolingHandler', 'BinaryElementwiseHandler', 'operator_registry'
|
||||
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,482 @@
|
|||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
|
||||
BroadcastType,
|
||||
get_broadcast_dim_info,
|
||||
get_broadcast_shape,
|
||||
)
|
||||
from colossalai.tensor.sharding_spec import ShardingSpecException
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from ..utils import recover_sharding_spec_for_broadcast_shape
|
||||
from .node_handler import NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import (
|
||||
BatchedMatMulStrategyGenerator,
|
||||
DotProductStrategyGenerator,
|
||||
LinearProjectionStrategyGenerator,
|
||||
MatVecStrategyGenerator,
|
||||
StrategyGenerator,
|
||||
)
|
||||
|
||||
|
||||
class MatMulType(Enum):
|
||||
"""
|
||||
The MatMulType is categorized into 4 types based on the reference of torch.matmul
|
||||
in https://pytorch.org/docs/stable/generated/torch.matmul.html.
|
||||
|
||||
DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements
|
||||
MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D
|
||||
MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
|
||||
BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
|
||||
"""
|
||||
DOT = 0
|
||||
MM = 1
|
||||
MV = 2
|
||||
BMM = 3
|
||||
|
||||
|
||||
def get_matmul_type(input_dim: int, other_dim: int):
|
||||
"""
|
||||
Determine which type of matmul operation should be executed for the given tensor dimensions.
|
||||
|
||||
Args:
|
||||
input_dim (int): the number of dimensions for the input tenosr
|
||||
other_dim (int): the number of dimensions for the other tenosr
|
||||
"""
|
||||
if input_dim == 1 and other_dim == 1:
|
||||
matmul_type = MatMulType.DOT
|
||||
elif input_dim in [1, 2] and other_dim == 2:
|
||||
matmul_type = MatMulType.MM
|
||||
elif input_dim == 2 and other_dim == 1:
|
||||
matmul_type = MatMulType.MV
|
||||
elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
|
||||
matmul_type = MatMulType.BMM
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation"
|
||||
)
|
||||
return matmul_type
|
||||
|
||||
|
||||
class BmmTransform(ABC):
|
||||
"""
|
||||
BmmTransform is an abstraction of the shape conversion between logical and physical operation data
|
||||
during the strategy generation.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, shape_mapping: Dict[str, List[int]]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
|
||||
pass
|
||||
|
||||
|
||||
class Padder(BmmTransform):
|
||||
"""
|
||||
Add padding to the matrix dimensions for batched matrix multiplication.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# keep the padding dim, op_name -> padded_dim
|
||||
self.padded_dim_mapping = {}
|
||||
|
||||
def apply(self, shape_mapping: Dict[str, List[int]]):
|
||||
mapping_copy = deepcopy(shape_mapping)
|
||||
input_shape = mapping_copy['input']
|
||||
other_shape = mapping_copy['other']
|
||||
|
||||
if len(input_shape) == 1:
|
||||
# if the input is a 1D tensor, 1 is prepended to its shape
|
||||
# and it will be removed afterwards
|
||||
input_shape.insert(0, 1)
|
||||
self.padded_dim_mapping['input'] = -2
|
||||
self.padded_dim_mapping['output'] = -2
|
||||
elif len(other_shape) == 1:
|
||||
# if the other is a 1D tensor, 1 is appended to its shape
|
||||
# and it will be removed afterwards
|
||||
other_shape = other_shape.append(1)
|
||||
self.padded_dim_mapping['other'] = -1
|
||||
self.padded_dim_mapping['output'] = -1
|
||||
return mapping_copy
|
||||
|
||||
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
|
||||
input_op_data = op_data_mapping['input']
|
||||
other_op_data = op_data_mapping['other']
|
||||
|
||||
def _remove_padded_dim(key, strategy):
|
||||
op_data = op_data_mapping[key]
|
||||
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
|
||||
tensor_shape = list(sharding_spec.entire_shape)
|
||||
dim_partition_list = [None] * len(tensor_shape)
|
||||
|
||||
# padded dim is a negative number as the padded dim must be a matrix dim
|
||||
padded_dim = self.padded_dim_mapping[key]
|
||||
|
||||
# compute the new dim partition
|
||||
for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
|
||||
dim_partition_list[tensor_dim] = mesh_dims
|
||||
dim_partition_list.pop(padded_dim)
|
||||
unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}
|
||||
|
||||
# compute unpadded tensor shape
|
||||
tensor_shape.pop(padded_dim)
|
||||
|
||||
assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
|
||||
|
||||
# update sharding spec
|
||||
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
|
||||
|
||||
# enumerate all sharding strategies
|
||||
strategies = []
|
||||
try:
|
||||
strategy_copy = strategy.clone()
|
||||
|
||||
# only one of input and other will be padded
|
||||
if 'input' in self.padded_dim_mapping:
|
||||
_remove_padded_dim('input', strategy_copy)
|
||||
_remove_padded_dim('output', strategy_copy)
|
||||
elif 'other' in self.padded_dim_mapping:
|
||||
_remove_padded_dim('other', strategy_copy)
|
||||
_remove_padded_dim('output', strategy_copy)
|
||||
|
||||
strategies.append(strategy_copy)
|
||||
except ShardingSpecException as e:
|
||||
pass
|
||||
return strategies
|
||||
|
||||
|
||||
class Broadcaster(BmmTransform):
|
||||
"""
|
||||
Broadcast the non-matrix dimensions for batched matrix multiplication.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.broadcast_dim_info = {}
|
||||
|
||||
def apply(self, shape_mapping: Dict[str, List[int]]):
|
||||
mapping_copy = shape_mapping.copy()
|
||||
|
||||
# get shapes
|
||||
input_shape = mapping_copy['input']
|
||||
other_shape = mapping_copy['other']
|
||||
|
||||
# sanity check
|
||||
assert len(input_shape) > 1 and len(other_shape) > 1
|
||||
|
||||
# broadcast the batch dim and record
|
||||
bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2])
|
||||
|
||||
# store the broadcast dim info
|
||||
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
|
||||
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
|
||||
self.broadcast_dim_info['input'] = input_broadcast_dim_info
|
||||
self.broadcast_dim_info['other'] = other_broadcast_dim_info
|
||||
|
||||
# create the full logical shape
|
||||
input_shape = bcast_non_matrix_dims + input_shape[-2:]
|
||||
other_shape = bcast_non_matrix_dims + other_shape[-2:]
|
||||
assert len(input_shape) == len(other_shape)
|
||||
|
||||
mapping_copy['input'] = input_shape
|
||||
mapping_copy['other'] = other_shape
|
||||
|
||||
return mapping_copy
|
||||
|
||||
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
|
||||
# remove sharding on the broadcast dim
|
||||
def _remove_sharding_on_broadcast_dim(key, strategy):
|
||||
op_data = op_data_mapping[key]
|
||||
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
|
||||
tensor_shape = list(sharding_spec.entire_shape)
|
||||
|
||||
for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
|
||||
if broadcast_type == BroadcastType.MULTIPLE:
|
||||
# if the dim is originally 1 and multiplied during broadcast
|
||||
# we set its sharding to R
|
||||
# e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
|
||||
# the dim 0 of [1, 2, 4] is multiplied to 4
|
||||
tensor_shape[dim_idx] = 1
|
||||
elif broadcast_type == BroadcastType.PADDDING:
|
||||
# if the dim is padded
|
||||
# we remove its sharding
|
||||
tensor_shape[dim_idx] = None
|
||||
|
||||
tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]
|
||||
|
||||
physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
|
||||
logical_sharding_spec=sharding_spec,
|
||||
logical_shape=sharding_spec.entire_shape,
|
||||
physical_shape=tensor_shape_before_broadcast)
|
||||
strategy.sharding_specs[op_data] = physical_sharding_spec
|
||||
|
||||
# enumerate all sharding strategies
|
||||
strategies = []
|
||||
try:
|
||||
strategy_copy = strategy.clone()
|
||||
_remove_sharding_on_broadcast_dim('input', strategy_copy)
|
||||
_remove_sharding_on_broadcast_dim('other', strategy_copy)
|
||||
strategies.append(strategy_copy)
|
||||
except ShardingSpecException as e:
|
||||
pass
|
||||
return strategies
|
||||
|
||||
|
||||
class Viewer(BmmTransform):
|
||||
"""
|
||||
Change the shape of the tensor from N-D to 3D
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.batch_dims_before_view = None
|
||||
|
||||
def apply(self, shape_mapping: Dict[str, List[int]]):
|
||||
mapping_copy = shape_mapping.copy()
|
||||
self.batch_dims_before_view = list(mapping_copy['input'][:-2])
|
||||
|
||||
# get shapes
|
||||
input_shape = shape_mapping['input']
|
||||
other_shape = shape_mapping['other']
|
||||
|
||||
# view to 3d tensor
|
||||
assert len(input_shape) >= 3 and len(other_shape) >= 3
|
||||
input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
|
||||
other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
|
||||
output_shape = input_shape[:2] + other_shape[2:]
|
||||
mapping_copy['input'] = input_shape
|
||||
mapping_copy['other'] = other_shape
|
||||
mapping_copy['output'] = output_shape
|
||||
return mapping_copy
|
||||
|
||||
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
|
||||
# get operation data
|
||||
def _update_sharding_spec(key, strategy, physical_batch_dim):
|
||||
"""
|
||||
Map the logical batch dim to the physical batch dim
|
||||
"""
|
||||
op_data = op_data_mapping[key]
|
||||
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
|
||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||
entire_shape = sharding_spec.entire_shape
|
||||
|
||||
# upddate the dimension index for the matrix dimensions
|
||||
if 2 in dim_partition_dict:
|
||||
dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
|
||||
if 1 in dim_partition_dict:
|
||||
dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)
|
||||
|
||||
# map the logical batch dim to phyiscal batch dim
|
||||
if 0 in dim_partition_dict:
|
||||
batch_dim_shard = dim_partition_dict.pop(0)
|
||||
dim_partition_dict[physical_batch_dim] = batch_dim_shard
|
||||
|
||||
# the new shape will be the batch dims + the last 2 matrix dims
|
||||
shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
|
||||
sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)
|
||||
|
||||
num_batch_dim_before_view = len(self.batch_dims_before_view)
|
||||
|
||||
# enumerate all sharding strategies
|
||||
strategies = []
|
||||
for i in range(num_batch_dim_before_view):
|
||||
# create a new strategy
|
||||
strategy_copy = strategy.clone()
|
||||
try:
|
||||
_update_sharding_spec('input', strategy_copy, i)
|
||||
_update_sharding_spec('other', strategy_copy, i)
|
||||
_update_sharding_spec('output', strategy_copy, i)
|
||||
strategies.append(strategy_copy)
|
||||
except ShardingSpecException as e:
|
||||
continue
|
||||
return strategies
|
||||
|
||||
|
||||
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
|
||||
"""
|
||||
Compute the logical shapes for BMM operation. BMM has a general representation
|
||||
[b, i, k] = [b, i, j] x [b, j, k]
|
||||
|
||||
The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions
|
||||
The logical shape for the bmm operands will undergo three stages
|
||||
1. append/prepend the 1 to the 1D tensor if there is any
|
||||
2. broadcast the non-matrix dimensions
|
||||
3. reshape to 3 dimensions
|
||||
|
||||
"""
|
||||
shape_mapping = {'input': input_shape, 'other': other_shape}
|
||||
|
||||
for transform in transforms:
|
||||
shape_mapping = transform.apply(shape_mapping)
|
||||
|
||||
input_shape = shape_mapping.get('input', None)
|
||||
other_shape = shape_mapping.get('other', None)
|
||||
output_shape = shape_mapping.get('output', None)
|
||||
|
||||
return input_shape, other_shape, output_shape
|
||||
|
||||
|
||||
@operator_registry.register(torch.matmul)
|
||||
@operator_registry.register(torch.Tensor.matmul)
|
||||
class MatMulHandler(NodeHandler):
|
||||
"""
|
||||
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
|
||||
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
|
||||
the operands.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# check which type of operation this matmul will call
|
||||
self.input_meta_data = self.node.args[0]._meta_data
|
||||
self.other_meta_data = self.node.args[1]._meta_data
|
||||
self.output_meta_data = self.node._meta_data
|
||||
|
||||
input_dim = self.input_meta_data.dim()
|
||||
other_dim = self.other_meta_data.dim()
|
||||
self.matmul_type = get_matmul_type(input_dim, other_dim)
|
||||
|
||||
if self.matmul_type == MatMulType.BMM:
|
||||
# bmm operation can possibly involve padding, broadcasting and view
|
||||
# these transforms will be used to create logical shape and
|
||||
# recover physical sharding spec
|
||||
self.transforms = [Padder(), Broadcaster(), Viewer()]
|
||||
else:
|
||||
self.transforms = None
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
generators = []
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
if self.matmul_type == MatMulType.BMM:
|
||||
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
|
||||
elif self.matmul_type == MatMulType.DOT:
|
||||
generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh))
|
||||
elif self.matmul_type == MatMulType.MV:
|
||||
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
|
||||
elif self.matmul_type == MatMulType.MM:
|
||||
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
logical_shape_func = {
|
||||
MatMulType.DOT: self._get_logical_shape_for_dot,
|
||||
MatMulType.MM: self._get_logical_shape_for_mm,
|
||||
MatMulType.MV: self._get_logical_shape_for_mv,
|
||||
MatMulType.BMM: self._get_logical_shape_for_bmm
|
||||
}
|
||||
logical_shapes = logical_shape_func[self.matmul_type]()
|
||||
op_data_mapping = self._get_op_data_mapping(*logical_shapes)
|
||||
return op_data_mapping
|
||||
|
||||
def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
|
||||
# convert list to torch.Size
|
||||
if input_logical_shape:
|
||||
input_logical_shape = torch.Size(input_logical_shape)
|
||||
|
||||
if other_logical_shape:
|
||||
other_logical_shape = torch.Size(other_logical_shape)
|
||||
|
||||
if output_logical_shape:
|
||||
output_logical_shape = torch.Size(output_logical_shape)
|
||||
|
||||
# create op data
|
||||
input_op_data = OperationData(name=str(self.node.args[0]),
|
||||
type=OperationDataType.ARG,
|
||||
data=self.input_meta_data,
|
||||
logical_shape=input_logical_shape)
|
||||
other_op_data = OperationData(name=str(self.node.args[1]),
|
||||
type=OperationDataType.ARG,
|
||||
data=self.other_meta_data,
|
||||
logical_shape=other_logical_shape)
|
||||
output_op_data = OperationData(name=str(self.node),
|
||||
type=OperationDataType.OUTPUT,
|
||||
data=self.output_meta_data,
|
||||
logical_shape=output_logical_shape)
|
||||
|
||||
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
|
||||
return mapping
|
||||
|
||||
def _get_logical_shape_for_dot(self):
|
||||
"""
|
||||
The operands for the dot operation have the same logical shape as the physical shape
|
||||
"""
|
||||
return None, None, None
|
||||
|
||||
def _get_logical_shape_for_mm(self):
|
||||
"""
|
||||
We need to handle the input tensor for a matrix-matrix multiplcation as the input
|
||||
tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
|
||||
(e.g. [4] -> [1, 4]).
|
||||
"""
|
||||
if self.input_meta_data.dim() == 1:
|
||||
input_logical_shape = [1] + list(self.input_meta_data.shape)
|
||||
input_logical_shape = torch.Size(input_logical_shape)
|
||||
else:
|
||||
input_logical_shape = None
|
||||
return input_logical_shape, None, None
|
||||
|
||||
def _get_logical_shape_for_mv(self):
|
||||
"""
|
||||
No broadcasting or dim insertion occurs for matrix-vector operation.
|
||||
"""
|
||||
return None, None, None
|
||||
|
||||
def _get_logical_shape_for_bmm(self):
|
||||
input_physical_shape = list(self.input_meta_data.shape)
|
||||
other_physical_shape = list(self.other_meta_data.shape)
|
||||
return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)
|
||||
|
||||
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
|
||||
if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
|
||||
return strategy
|
||||
elif self.matmul_type == MatMulType.MM:
|
||||
if self.input_meta_data.dim() == 1:
|
||||
# if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
|
||||
# we need to remove that dim
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
|
||||
input_physical_shape = self.node.args[0]._meta_data.shape
|
||||
dim_partition_dict = input_sharding_spec.dim_partition_dict
|
||||
|
||||
# remove the partitioning in the dim 0
|
||||
if 0 in dim_partition_dict:
|
||||
dim_partition_dict.pop(0, None)
|
||||
|
||||
# move the partitioning in dim 1 to dim 0
|
||||
if -1 in dim_partition_dict:
|
||||
shard = dim_partition_dict.pop(-1)
|
||||
dim_partition_dict[0] = shard
|
||||
|
||||
# re-init the sharding spec
|
||||
input_sharding_spec.__init__(input_sharding_spec.device_mesh,
|
||||
entire_shape=input_physical_shape,
|
||||
dim_partition_dict=dim_partition_dict)
|
||||
return strategy
|
||||
else:
|
||||
return strategy
|
||||
elif self.matmul_type == MatMulType.BMM:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
|
||||
strategies = [strategy]
|
||||
# recover the physical sharding spec
|
||||
for transform in self.transforms[::-1]:
|
||||
recovered_stragies = []
|
||||
for strategy_ in strategies:
|
||||
output = transform.recover(op_data_mapping, strategy_)
|
||||
if isinstance(output, ShardingStrategy):
|
||||
recovered_stragies.append(output)
|
||||
elif isinstance(output, (list, tuple)):
|
||||
recovered_stragies.extend(output)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
|
||||
strategies = recovered_stragies
|
||||
return strategies
|
|
@ -60,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
|
|||
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
|
||||
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
|
||||
fwd_compute_cost = sharded_input_shape[0]
|
||||
bwd_compute_cost = sharded_input_shape * 2
|
||||
bwd_compute_cost = fwd_compute_cost * 2
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
|
||||
bwd=bwd_compute_cost,
|
||||
total=fwd_compute_cost + bwd_compute_cost)
|
||||
return compute_cost
|
||||
|
||||
@ignore_sharding_exception
|
||||
def no_split(self):
|
||||
name = f'R = R dot R'
|
||||
dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
|
||||
|
@ -75,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_one_dim(self, mesh_dim):
|
||||
name = f'R = S{mesh_dim} dot S{mesh_dim}'
|
||||
|
||||
|
@ -93,7 +95,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def generate(self) -> List[ShardingStrategy]:
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
|
||||
# do not split dimensions for dot product
|
||||
|
@ -113,24 +115,50 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
|
|||
def validate(self) -> bool:
|
||||
input_op_data = self.op_data['input']
|
||||
other_op_data = self.op_data['other']
|
||||
assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
|
||||
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
|
||||
|
||||
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
|
||||
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
|
||||
fwd_compute_cost = sharded_input_shape[0]
|
||||
bwd_compute_cost = fwd_compute_cost * 2
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
|
||||
bwd=bwd_compute_cost,
|
||||
total=fwd_compute_cost + bwd_compute_cost)
|
||||
return compute_cost
|
||||
|
||||
@ignore_sharding_exception
|
||||
def no_split(self):
|
||||
name = "R = R x R"
|
||||
dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
|
||||
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
|
||||
|
||||
if self.has_bias:
|
||||
dim_partition_dict['bias'] = {}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch(self, mesh_dim):
|
||||
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
|
||||
|
||||
# get sharding spec
|
||||
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim]
|
||||
},
|
||||
"other": {},
|
||||
"output": {
|
||||
0: [mesh_dim]
|
||||
},
|
||||
}
|
||||
|
||||
if self.has_bias:
|
||||
dim_partition_dict['bias'] = {}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
|
||||
|
||||
# get communication action
|
||||
communication_action_mapping = {}
|
||||
if self.is_param('other'):
|
||||
other_comm_action = self.get_communication_action(
|
||||
sharding_spec=sharding_spec_mapping['other'],
|
||||
|
@ -144,6 +172,8 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
|
|||
logical_process_axis=mesh_dim,
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=1)
|
||||
communication_action_mapping['other'] = other_comm_action
|
||||
|
||||
if self.has_bias:
|
||||
if self.is_param('bias'):
|
||||
bias_comm_action = self.get_communication_action(
|
||||
|
@ -158,13 +188,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
|
|||
logical_process_axis=mesh_dim,
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=2)
|
||||
communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
|
||||
communication_action_mapping['bias'] = bias_comm_action
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def generate(self) -> List[ShardingStrategy]:
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
|
||||
# no split
|
||||
|
@ -638,7 +668,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
|||
def validate(self) -> bool:
|
||||
input_op_data = self.op_data['input']
|
||||
other_op_data = self.op_data['other']
|
||||
assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3
|
||||
assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
|
||||
|
||||
if 'bias' in self.op_data:
|
||||
bias_op_data = self.op_data['bias']
|
||||
|
@ -816,11 +846,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
|||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0],
|
||||
-1: [mesh_dim_1]
|
||||
2: [mesh_dim_1]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim_0],
|
||||
-2: [mesh_dim_1]
|
||||
1: [mesh_dim_1]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
|
|
|
@ -186,9 +186,14 @@ class StrategyGenerator(ABC):
|
|||
"""
|
||||
op_data = self.op_data[key]
|
||||
sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
|
||||
|
||||
if len(sharded_shape) == 0:
|
||||
num_elements = 1
|
||||
else:
|
||||
num_elements = reduce(operator.mul, sharded_shape)
|
||||
dtype = self.op_data[key].data.dtype
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
|
||||
return num_elements * size_per_elem_bytes
|
||||
|
||||
def generate(self) -> List[ShardingStrategy]:
|
||||
"""
|
||||
|
|
|
@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
|
|||
return dims[::-1]
|
||||
|
||||
|
||||
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
|
||||
physical_shape: torch.Size) -> ShardingSpec:
|
||||
"""
|
||||
This function computes the sharding spec for the physical shape of a broadcast tensor.
|
||||
|
||||
Args:
|
||||
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
|
||||
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
|
||||
physical_shape (torch.Size): the shape of the tensor before broadcasting
|
||||
"""
|
||||
# if the two shapes are the same, no broadcast occurs
|
||||
# we directly return the current sharding spec
|
||||
if list(logical_shape) == list(physical_shape):
|
||||
return logical_sharding_spec
|
||||
|
||||
def get_broadcast_dim_info(logical_shape, physical_shape):
|
||||
# get the number of dimensions
|
||||
logical_num_dims = len(logical_shape)
|
||||
physical_num_dims = len(physical_shape)
|
||||
|
@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
|
|||
else:
|
||||
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
|
||||
|
||||
return logical_dim_broadcast_info
|
||||
|
||||
|
||||
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
|
||||
physical_shape: torch.Size) -> ShardingSpec:
|
||||
"""
|
||||
This function computes the sharding spec for the physical shape of a broadcast tensor.
|
||||
|
||||
Args:
|
||||
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
|
||||
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
|
||||
physical_shape (torch.Size): the shape of the tensor before broadcasting
|
||||
"""
|
||||
# if the two shapes are the same, no broadcast occurs
|
||||
# we directly return the current sharding spec
|
||||
if list(logical_shape) == list(physical_shape):
|
||||
return logical_sharding_spec
|
||||
|
||||
# get the number of dimensions
|
||||
logical_num_dims = len(logical_shape)
|
||||
physical_num_dims = len(physical_shape)
|
||||
|
||||
# get the broadcast info
|
||||
logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
|
||||
|
||||
# generate the sharding spec for the physical shape
|
||||
physical_dim_partition = {}
|
||||
logical_dim_partition = logical_sharding_spec.dim_partition_dict
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import operator
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
|
@ -175,6 +174,9 @@ class ShardingSpec:
|
|||
dim_partition_dict=None,
|
||||
sharding_sequence=None):
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
if isinstance(entire_shape, (list, tuple)):
|
||||
entire_shape = torch.Size(entire_shape)
|
||||
self.entire_shape = entire_shape
|
||||
self.dim_partition_dict = dim_partition_dict
|
||||
self.sharding_sequence = sharding_sequence
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
|
||||
MatMulHandler,
|
||||
MatMulType,
|
||||
_get_bmm_logical_shape,
|
||||
get_matmul_type,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing.utils import parameterize
|
||||
|
||||
|
||||
class MatMulModule(nn.Module):
|
||||
|
||||
def forward(self, x1, x2):
|
||||
return torch.matmul(x1, x2)
|
||||
|
||||
|
||||
@parameterize(
|
||||
'tensor_shapes',
|
||||
[
|
||||
[[8], [8]], # dot product
|
||||
[[4, 8], [8]], # mat-vec product
|
||||
[[4, 8], [8, 16]], # mat-mat product
|
||||
[[8], [8, 16]], # mat-mat product
|
||||
[[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting
|
||||
[[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting
|
||||
[[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting
|
||||
[[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting
|
||||
[[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
|
||||
[[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
|
||||
[[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
|
||||
[[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
|
||||
[[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
|
||||
[[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting
|
||||
])
|
||||
def test_matmul_node_handler(tensor_shapes):
|
||||
input_shape, other_shape = tensor_shapes
|
||||
|
||||
# get output shape
|
||||
x1 = torch.rand(*input_shape)
|
||||
x2 = torch.rand(*other_shape)
|
||||
output_shape = list(torch.matmul(x1, x2).shape)
|
||||
|
||||
# get matmul type
|
||||
matmul_type = get_matmul_type(x1.dim(), x2.dim())
|
||||
|
||||
model = MatMulModule()
|
||||
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
print(graph)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
mod_node = list(graph.nodes)[2]
|
||||
strategies_vector = StrategiesVector(mod_node)
|
||||
|
||||
# build handler
|
||||
handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
|
||||
|
||||
# check operation data mapping
|
||||
mapping = handler.get_operation_data_mapping()
|
||||
|
||||
for name, op_data in mapping.items():
|
||||
op_data: OperationData
|
||||
# make sure they have valid values
|
||||
assert op_data.logical_shape is not None
|
||||
assert op_data.data is not None
|
||||
|
||||
logical_input_shape = input_shape
|
||||
logical_other_shape = other_shape
|
||||
logical_output_shape = output_shape
|
||||
if matmul_type == MatMulType.MM and len(input_shape) == 1:
|
||||
logical_input_shape = [1] + input_shape
|
||||
elif matmul_type == MatMulType.BMM:
|
||||
logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape(
|
||||
input_shape, other_shape, handler.transforms)
|
||||
else:
|
||||
logical_input_shape = input_shape
|
||||
|
||||
# check input operation data
|
||||
assert mapping['input'].name == "x1"
|
||||
assert mapping['input'].data.is_meta
|
||||
assert mapping['input'].data.shape == torch.Size(input_shape)
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size(logical_input_shape)
|
||||
|
||||
# check other operation data
|
||||
assert mapping['other'].name == "x2"
|
||||
assert mapping['other'].data.is_meta
|
||||
assert mapping['other'].data.shape == torch.Size(other_shape)
|
||||
assert mapping['other'].type == OperationDataType.ARG
|
||||
assert mapping['other'].logical_shape == torch.Size(logical_other_shape)
|
||||
|
||||
# check output
|
||||
assert mapping['output'].name == "matmul"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size(output_shape)
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
assert mapping['output'].logical_shape == torch.Size(logical_output_shape)
|
||||
|
||||
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
||||
strategy_name_list = [val.name for val in strategies_vector]
|
||||
|
||||
# ensure there is no duplicate strategy
|
||||
if matmul_type != MatMulType.BMM:
|
||||
assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list
|
||||
|
||||
for strategy in strategies_vector:
|
||||
strategy: ShardingStrategy
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
|
||||
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')
|
||||
|
||||
if matmul_type == MatMulType.DOT:
|
||||
# dot product will produce a scaler
|
||||
# results should fulfill:
|
||||
# 1. the input and other operands have the same sharding spec
|
||||
# 2. the output has no sharding
|
||||
assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
|
||||
assert len(output_sharding_spec.sharding_sequence) == 0
|
||||
elif matmul_type == MatMulType.MV:
|
||||
# matrix-vector product should fulfill
|
||||
# 1. the last dim of the input and other operands should have the same sharding
|
||||
# 2. the first dim of the input and other should have the same sharding
|
||||
# 3. the output should have only 1 dim
|
||||
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
|
||||
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
|
||||
assert len(output_sharding_spec.sharding_sequence) == 1
|
||||
elif matmul_type == MatMulType.MM:
|
||||
# matrix-matrix multiplication should fulfil
|
||||
# 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding
|
||||
# 2. the input's last dim and the first dim of the other should have the same sharding
|
||||
# 3. the last dim of the output and other should have the same sharding
|
||||
# 4. the input and output should have the same number of dims
|
||||
if len(input_shape) == 2:
|
||||
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0]
|
||||
assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
|
||||
assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence)
|
||||
elif matmul_type == MatMulType.BMM:
|
||||
# bmm should fulfil
|
||||
# 1. of the other tensor is not a 1d tensor, the last dim of other and output have the same sharding
|
||||
# 2. if the input has more than 2 dim, the second last dim of input and output have the same sharding
|
||||
# 3. if the other have more than 2 dim, the second last dim of other and the last dim of input should have the same sharding
|
||||
if len(other_shape) > 1:
|
||||
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
if len(input_shape) > 1:
|
||||
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
|
||||
if len(other_shape) > 2:
|
||||
assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_matmul_node_handler()
|
Loading…
Reference in New Issue