[autoparallel] added matmul handler (#1763)

* [autoparallel] added matmul handler

* polish code
Frank Lee 2022-11-01 15:14:53 +08:00 committed by GitHub
parent 4df0194976
commit f3f19a5c47
7 changed files with 725 additions and 28 deletions


@@ -4,6 +4,7 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OuputHandler
from .placeholder_handler import PlacehodlerHandler
@@ -16,5 +17,5 @@ __all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'operator_registry'
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry'
]


@@ -0,0 +1,482 @@
import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import Dict, List, Union
import torch
from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
BroadcastType,
get_broadcast_dim_info,
get_broadcast_shape,
)
from colossalai.tensor.sharding_spec import ShardingSpecException
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
DotProductStrategyGenerator,
LinearProjectionStrategyGenerator,
MatVecStrategyGenerator,
StrategyGenerator,
)
class MatMulType(Enum):
"""
The MatMulType is categorized into 4 types based on the reference of torch.matmul
in https://pytorch.org/docs/stable/generated/torch.matmul.html.
DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements
MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D
MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
"""
DOT = 0
MM = 1
MV = 2
BMM = 3
def get_matmul_type(input_dim: int, other_dim: int):
"""
Determine which type of matmul operation should be executed for the given tensor dimensions.
Args:
input_dim (int): the number of dimensions for the input tensor
other_dim (int): the number of dimensions for the other tensor
"""
if input_dim == 1 and other_dim == 1:
matmul_type = MatMulType.DOT
elif input_dim in [1, 2] and other_dim == 2:
matmul_type = MatMulType.MM
elif input_dim == 2 and other_dim == 1:
matmul_type = MatMulType.MV
elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
matmul_type = MatMulType.BMM
else:
raise ValueError(
f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation"
)
return matmul_type
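# Illustration: a minimal sketch of how tensor ranks map to MatMulType;
# each case follows directly from the branches above.
assert get_matmul_type(1, 1) == MatMulType.DOT    # e.g. [8] @ [8]
assert get_matmul_type(1, 2) == MatMulType.MM     # e.g. [8] @ [8, 16]
assert get_matmul_type(2, 2) == MatMulType.MM     # e.g. [4, 8] @ [8, 16]
assert get_matmul_type(2, 1) == MatMulType.MV     # e.g. [4, 8] @ [8]
assert get_matmul_type(3, 2) == MatMulType.BMM    # e.g. [4, 8, 16] @ [16, 32]
assert get_matmul_type(2, 4) == MatMulType.BMM    # e.g. [8, 16] @ [2, 4, 16, 32]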
class BmmTransform(ABC):
"""
BmmTransform is an abstraction of the shape conversion between logical and physical operation data
during the strategy generation.
"""
@abstractmethod
def apply(self, shape_mapping: Dict[str, List[int]]):
pass
@abstractmethod
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
pass
class Padder(BmmTransform):
"""
Add padding to the matrix dimensions for batched matrix multiplication.
"""
def __init__(self) -> None:
# keep the padding dim, op_name -> padded_dim
self.padded_dim_mapping = {}
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = deepcopy(shape_mapping)
input_shape = mapping_copy['input']
other_shape = mapping_copy['other']
if len(input_shape) == 1:
# if the input is a 1D tensor, 1 is prepended to its shape
# and it will be removed afterwards
input_shape.insert(0, 1)
self.padded_dim_mapping['input'] = -2
self.padded_dim_mapping['output'] = -2
elif len(other_shape) == 1:
# if the other is a 1D tensor, 1 is appended to its shape
# and it will be removed afterwards
other_shape.append(1)
self.padded_dim_mapping['other'] = -1
self.padded_dim_mapping['output'] = -1
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
input_op_data = op_data_mapping['input']
other_op_data = op_data_mapping['other']
def _remove_padded_dim(key, strategy):
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
tensor_shape = list(sharding_spec.entire_shape)
dim_partition_list = [None] * len(tensor_shape)
# padded dim is a negative number as the padded dim must be a matrix dim
padded_dim = self.padded_dim_mapping[key]
# compute the new dim partition
for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
dim_partition_list[tensor_dim] = mesh_dims
dim_partition_list.pop(padded_dim)
unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}
# compute unpadded tensor shape
tensor_shape.pop(padded_dim)
assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
# update sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
# only one of input and other will be padded
if 'input' in self.padded_dim_mapping:
_remove_padded_dim('input', strategy_copy)
_remove_padded_dim('output', strategy_copy)
elif 'other' in self.padded_dim_mapping:
_remove_padded_dim('other', strategy_copy)
_remove_padded_dim('output', strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
pass
return strategies
class Broadcaster(BmmTransform):
"""
Broadcast the non-matrix dimensions for batched matrix multiplication.
"""
def __init__(self) -> None:
self.broadcast_dim_info = {}
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
# get shapes
input_shape = mapping_copy['input']
other_shape = mapping_copy['other']
# sanity check
assert len(input_shape) > 1 and len(other_shape) > 1
# broadcast the batch dim and record
bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2])
# store the broadcast dim info
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
self.broadcast_dim_info['input'] = input_broadcast_dim_info
self.broadcast_dim_info['other'] = other_broadcast_dim_info
# create the full logical shape
input_shape = bcast_non_matrix_dims + input_shape[-2:]
other_shape = bcast_non_matrix_dims + other_shape[-2:]
assert len(input_shape) == len(other_shape)
mapping_copy['input'] = input_shape
mapping_copy['other'] = other_shape
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
# remove sharding on the broadcast dim
def _remove_sharding_on_broadcast_dim(key, strategy):
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
tensor_shape = list(sharding_spec.entire_shape)
for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
if broadcast_type == BroadcastType.MULTIPLE:
# if the dim is originally 1 and multiplied during broadcast
# we set its sharding to R
# e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
# the dim 0 of [1, 2, 4] is multiplied to 4
tensor_shape[dim_idx] = 1
elif broadcast_type == BroadcastType.PADDDING:
# if the dim is padded
# we remove its sharding
tensor_shape[dim_idx] = None
tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]
physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape,
physical_shape=tensor_shape_before_broadcast)
strategy.sharding_specs[op_data] = physical_sharding_spec
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
_remove_sharding_on_broadcast_dim('input', strategy_copy)
_remove_sharding_on_broadcast_dim('other', strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
pass
return strategies
class Viewer(BmmTransform):
"""
Change the shape of the tensor from N-D to 3D
"""
def __init__(self) -> None:
self.batch_dims_before_view = None
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
self.batch_dims_before_view = list(mapping_copy['input'][:-2])
# get shapes
input_shape = shape_mapping['input']
other_shape = shape_mapping['other']
# view to 3d tensor
assert len(input_shape) >= 3 and len(other_shape) >= 3
input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
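# [b, i, j] x [b, j, k] -> [b, i, k]: keep the flattened batch dim and i from the input, and k from the other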
output_shape = input_shape[:2] + other_shape[2:]
mapping_copy['input'] = input_shape
mapping_copy['other'] = other_shape
mapping_copy['output'] = output_shape
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
# get operation data
def _update_sharding_spec(key, strategy, physical_batch_dim):
"""
Map the logical batch dim to the physical batch dim
"""
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
dim_partition_dict = sharding_spec.dim_partition_dict
entire_shape = sharding_spec.entire_shape
# update the dimension index for the matrix dimensions
if 2 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
if 1 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)
# map the logical batch dim to the physical batch dim
if 0 in dim_partition_dict:
batch_dim_shard = dim_partition_dict.pop(0)
dim_partition_dict[physical_batch_dim] = batch_dim_shard
# the new shape will be the batch dims + the last 2 matrix dims
shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)
num_batch_dim_before_view = len(self.batch_dims_before_view)
# enumerate all sharding strategies
strategies = []
for i in range(num_batch_dim_before_view):
# create a new strategy
strategy_copy = strategy.clone()
try:
_update_sharding_spec('input', strategy_copy, i)
_update_sharding_spec('other', strategy_copy, i)
_update_sharding_spec('output', strategy_copy, i)
strategies.append(strategy_copy)
except ShardingSpecException as e:
continue
return strategies
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
"""
Compute the logical shapes for BMM operation. BMM has a general representation
[b, i, k] = [b, i, j] x [b, j, k]
The dimension b is called the non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions
The logical shape for the bmm operands will undergo three stages
1. prepend/append a 1 to the shape of a 1D tensor if there is any
2. broadcast the non-matrix dimensions
3. reshape to 3 dimensions
"""
shape_mapping = {'input': input_shape, 'other': other_shape}
for transform in transforms:
shape_mapping = transform.apply(shape_mapping)
input_shape = shape_mapping.get('input', None)
other_shape = shape_mapping.get('other', None)
output_shape = shape_mapping.get('output', None)
return input_shape, other_shape, output_shape
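# Illustration: a minimal sketch of the three-stage pipeline above on concrete
# shapes; the expected values assume get_broadcast_shape follows standard
# PyTorch broadcasting semantics.
# [8] x [4, 8, 16]: the 1D input is padded to [1, 8], broadcast against the
# batch dim [4], then both operands are viewed as 3D tensors.
example_shapes = _get_bmm_logical_shape([8], [4, 8, 16], [Padder(), Broadcaster(), Viewer()])
assert example_shapes == ([4, 1, 8], [4, 8, 16], [4, 1, 16])
# [2, 1, 8, 16] x [2, 4, 16, 32]: no padding is needed, the batch dims broadcast
# to [2, 4] and are then flattened into a single batch dim of size 8.
example_shapes = _get_bmm_logical_shape([2, 1, 8, 16], [2, 4, 16, 32], [Padder(), Broadcaster(), Viewer()])
assert example_shapes == ([8, 8, 16], [8, 16, 32], [8, 8, 32])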
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
"""
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
the operands.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# check which type of operation this matmul will call
self.input_meta_data = self.node.args[0]._meta_data
self.other_meta_data = self.node.args[1]._meta_data
self.output_meta_data = self.node._meta_data
input_dim = self.input_meta_data.dim()
other_dim = self.other_meta_data.dim()
self.matmul_type = get_matmul_type(input_dim, other_dim)
if self.matmul_type == MatMulType.BMM:
# bmm operation can possibly involve padding, broadcasting and view
# these transforms will be used to create logical shape and
# recover physical sharding spec
self.transforms = [Padder(), Broadcaster(), Viewer()]
else:
self.transforms = None
def get_strategy_generator(self) -> List[StrategyGenerator]:
generators = []
op_data_mapping = self.get_operation_data_mapping()
if self.matmul_type == MatMulType.BMM:
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.DOT:
generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MV:
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM:
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
logical_shape_func = {
MatMulType.DOT: self._get_logical_shape_for_dot,
MatMulType.MM: self._get_logical_shape_for_mm,
MatMulType.MV: self._get_logical_shape_for_mv,
MatMulType.BMM: self._get_logical_shape_for_bmm
}
logical_shapes = logical_shape_func[self.matmul_type]()
op_data_mapping = self._get_op_data_mapping(*logical_shapes)
return op_data_mapping
def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
# convert list to torch.Size
if input_logical_shape:
input_logical_shape = torch.Size(input_logical_shape)
if other_logical_shape:
other_logical_shape = torch.Size(other_logical_shape)
if output_logical_shape:
output_logical_shape = torch.Size(output_logical_shape)
# create op data
input_op_data = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.input_meta_data,
logical_shape=input_logical_shape)
other_op_data = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.other_meta_data,
logical_shape=other_logical_shape)
output_op_data = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.output_meta_data,
logical_shape=output_logical_shape)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
return mapping
def _get_logical_shape_for_dot(self):
"""
The operands for the dot operation have the same logical shape as the physical shape
"""
return None, None, None
def _get_logical_shape_for_mm(self):
"""
We need to handle the input tensor for a matrix-matrix multiplication as the input
tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
(e.g. [4] -> [1, 4]).
"""
if self.input_meta_data.dim() == 1:
input_logical_shape = [1] + list(self.input_meta_data.shape)
input_logical_shape = torch.Size(input_logical_shape)
else:
input_logical_shape = None
return input_logical_shape, None, None
def _get_logical_shape_for_mv(self):
"""
No broadcasting or dim insertion occurs for matrix-vector operation.
"""
return None, None, None
def _get_logical_shape_for_bmm(self):
input_physical_shape = list(self.input_meta_data.shape)
other_physical_shape = list(self.other_meta_data.shape)
return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
return strategy
elif self.matmul_type == MatMulType.MM:
if self.input_meta_data.dim() == 1:
# if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
# we need to remove that dim
input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
input_physical_shape = self.node.args[0]._meta_data.shape
dim_partition_dict = input_sharding_spec.dim_partition_dict
# remove the partitioning in the dim 0
if 0 in dim_partition_dict:
dim_partition_dict.pop(0, None)
# move the partitioning in dim 1 to dim 0
if -1 in dim_partition_dict:
shard = dim_partition_dict.pop(-1)
dim_partition_dict[0] = shard
# re-init the sharding spec
input_sharding_spec.__init__(input_sharding_spec.device_mesh,
entire_shape=input_physical_shape,
dim_partition_dict=dim_partition_dict)
return strategy
else:
return strategy
elif self.matmul_type == MatMulType.BMM:
op_data_mapping = self.get_operation_data_mapping()
strategies = [strategy]
# recover the physical sharding spec
for transform in self.transforms[::-1]:
recovered_strategies = []
for strategy_ in strategies:
output = transform.recover(op_data_mapping, strategy_)
if isinstance(output, ShardingStrategy):
recovered_strategies.append(output)
elif isinstance(output, (list, tuple)):
recovered_strategies.extend(output)
else:
raise TypeError(
f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
strategies = recovered_strategies
return strategies


@@ -60,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = sharded_input_shape * 2
bwd_compute_cost = fwd_compute_cost * 2
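# the backward pass produces gradients for both operands, so it is modeled as twice the forward cost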
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
return compute_cost
@ignore_sharding_exception
def no_split(self):
name = f'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
@@ -75,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_one_dim(self, mesh_dim):
name = f'R = S{mesh_dim} dot S{mesh_dim}'
@@ -93,7 +95,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# do not split dimensions for dot product
@@ -113,24 +115,50 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
return compute_cost
@ignore_sharding_exception
def no_split(self):
name = "R = R x R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
if self.has_bias:
dim_partition_dict['bias'] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
@ignore_sharding_exception
def split_input_batch(self, mesh_dim):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
dim_partition_dict = {
"input": {
0: [mesh_dim]
},
"other": {},
"output": {
0: [mesh_dim]
},
}
if self.has_bias:
dim_partition_dict['bias'] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
communication_action_mapping = {}
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
@@ -144,6 +172,8 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
@@ -158,13 +188,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=2)
communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# no split
@@ -638,7 +668,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3
assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
if 'bias' in self.op_data:
bias_op_data = self.op_data['bias']
@@ -816,11 +846,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
1: [mesh_dim_1]
},
"bias": {},
"output": {


@@ -186,9 +186,14 @@ class StrategyGenerator(ABC):
"""
op_data = self.op_data[key]
sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
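# a 0-dim (scalar) tensor still holds one element; reduce over its empty shape would raise a TypeError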
if len(sharded_shape) == 0:
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
dtype = self.op_data[key].data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
return num_elements * size_per_elem_bytes
def generate(self) -> List[ShardingStrategy]:
"""


@@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
return dims[::-1]
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
physical_shape: torch.Size) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
Args:
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec
if list(logical_shape) == list(physical_shape):
return logical_sharding_spec
def get_broadcast_dim_info(logical_shape, physical_shape):
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
@@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
else:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
return logical_dim_broadcast_info
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
physical_shape: torch.Size) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
Args:
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec
if list(logical_shape) == list(physical_shape):
return logical_sharding_spec
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
# get the broadcast info
logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
# generate the sharding spec for the physical shape
physical_dim_partition = {}
logical_dim_partition = logical_sharding_spec.dim_partition_dict


@@ -1,6 +1,5 @@
import operator
from copy import deepcopy
from enum import Enum
from functools import reduce
import torch
@@ -175,6 +174,9 @@ class ShardingSpec:
dim_partition_dict=None,
sharding_sequence=None):
self.device_mesh = device_mesh
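# accept plain lists/tuples and normalize them to torch.Size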
if isinstance(entire_shape, (list, tuple)):
entire_shape = torch.Size(entire_shape)
self.entire_shape = entire_shape
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence


@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
MatMulHandler,
MatMulType,
_get_bmm_logical_shape,
get_matmul_type,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize
class MatMulModule(nn.Module):
def forward(self, x1, x2):
return torch.matmul(x1, x2)
@parameterize(
'tensor_shapes',
[
[[8], [8]], # dot product
[[4, 8], [8]], # mat-vec product
[[4, 8], [8, 16]], # mat-mat product
[[8], [8, 16]], # mat-mat product
[[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting
[[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting
[[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting
[[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting
[[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting
])
def test_matmul_node_handler(tensor_shapes):
input_shape, other_shape = tensor_shapes
# get output shape
x1 = torch.rand(*input_shape)
x2 = torch.rand(*other_shape)
output_shape = list(torch.matmul(x1, x2).shape)
# get matmul type
matmul_type = get_matmul_type(x1.dim(), x2.dim())
model = MatMulModule()
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
print(graph)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(mod_node)
# build handler
handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
logical_input_shape = input_shape
logical_other_shape = other_shape
logical_output_shape = output_shape
if matmul_type == MatMulType.MM and len(input_shape) == 1:
logical_input_shape = [1] + input_shape
elif matmul_type == MatMulType.BMM:
logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape(
input_shape, other_shape, handler.transforms)
else:
logical_input_shape = input_shape
# check input operation data
assert mapping['input'].name == "x1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size(input_shape)
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size(logical_input_shape)
# check other operation data
assert mapping['other'].name == "x2"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size(other_shape)
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size(logical_other_shape)
# check output
assert mapping['output'].name == "matmul"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size(output_shape)
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size(logical_output_shape)
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# ensure there is no duplicate strategy
if matmul_type != MatMulType.BMM:
assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list
for strategy in strategies_vector:
strategy: ShardingStrategy
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')
if matmul_type == MatMulType.DOT:
# dot product will produce a scalar
# results should fulfill:
# 1. the input and other operands have the same sharding spec
# 2. the output has no sharding
assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
assert len(output_sharding_spec.sharding_sequence) == 0
elif matmul_type == MatMulType.MV:
# matrix-vector product should fulfill
# 1. the last dim of the input and other operands should have the same sharding
# 2. the first dim of the input and the output should have the same sharding
# 3. the output should have only 1 dim
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert len(output_sharding_spec.sharding_sequence) == 1
elif matmul_type == MatMulType.MM:
# matrix-matrix multiplication should fulfill
# 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding
# 2. the input's last dim and the first dim of the other should have the same sharding
# 3. the last dim of the output and other should have the same sharding
# 4. the input and output should have the same number of dims
if len(input_shape) == 2:
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0]
assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence)
elif matmul_type == MatMulType.BMM:
# bmm should fulfill
# 1. if the other tensor is not a 1D tensor, the last dim of other and output should have the same sharding
# 2. if the input is not a 1D tensor, the second-to-last dim of input and output should have the same sharding
# 3. if the other has more than 2 dims, the second-to-last dim of other and the last dim of input should have the same sharding
if len(other_shape) > 1:
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if len(input_shape) > 1:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
if len(other_shape) > 2:
assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]
if __name__ == '__main__':
test_matmul_node_handler()