[autoparallel] added all non-bcast matmul strategies (#1603)

pull/1606/head
Frank Lee 2022-09-16 10:47:32 +08:00 committed by GitHub
parent db98b695b2
commit 3abf98a633
2 changed files with 251 additions and 3 deletions

@@ -13,9 +13,238 @@ from typing import List
__all__ = ['DotHandler']
class DotProductStrategyGenerator(StrategyGenerator):
"""
DotProductStrategyGenerator is used to generate the sharding strategies for two 1D tensors in dot product computation.
This is created for torch.matmul where both tensors are 1D. Since torch.matmul does not accept a bias argument, bias is
not considered here.
"""
def validate(self, input, other):
assert input.dim() == 1 and other.dim() == 1
def no_split(self):
name = 'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_one_dim(self, mesh_dim):
name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
def generate(self) -> List[IntermediateStrategy]:
strategy_list = []
# do not split dimensions for dot product
# R = R dot R
strategy_list.append(self.no_split())
# split both tensors along the same dimension, trying each mesh axis
# S = S dot S
strategy_list.append(self.split_one_dim(0))
strategy_list.append(self.split_one_dim(1))
return strategy_list
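
The split strategy relies on the identity that the dot product of two identically sharded vectors is the sum of the per-shard partial dot products; the all-reduce over `mesh_dim` performs exactly that sum at runtime. A minimal NumPy sketch of the identity (illustration only, not part of this commit):

import numpy as np

x = np.arange(8.0)
y = np.arange(8.0, 16.0)
# shard both vectors identically across 4 "devices", then reduce the partial results
partials = [np.dot(xs, ys) for xs, ys in zip(np.split(x, 4), np.split(y, 4))]
assert np.isclose(sum(partials), np.dot(x, y))  # all-reduce(partials) == full dot product
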
class MatVecStrategyGenerator(StrategyGenerator):
def validate(self, input, other) -> bool:
assert input.dim() > 1 and other.dim() == 1
def no_split(self):
name = "R = R x R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_input_batch(self, mesh_dim):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def generate(self) -> List[IntermediateStrategy]:
strategy_list = []
# no split
strategy_list.append(self.no_split())
# split the batch dim of the input tensor only, trying each mesh axis
strategy_list.append(self.split_input_batch(0))
strategy_list.append(self.split_input_batch(1))
return strategy_list
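
Splitting only the batch dimension of the first operand needs no communication: each device multiplies its row shard by the fully replicated vector and already holds its shard of the output, which is why `split_input_batch` sets no `all_reduce_axis`. A small NumPy sketch (illustration only):

import numpy as np

inp = np.random.rand(6, 4)  # [batch, p], sharded along dim 0
vec = np.random.rand(4)     # [p], replicated on every device
out_shards = [shard @ vec for shard in np.split(inp, 2)]  # no reduction needed
assert np.allclose(np.concatenate(out_shards), inp @ vec)
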
class MatMulStrategyGenerator(StrategyGenerator):
"""
MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
A matmul can be formulated as [n, p] x [p, q] = [n, q]
Args:
is_linear (bool): whether this generator is used for nn.Linear and F.linear.
This will incur extra transformation of the dim partitioning as the weight is transposed.
"""
def __init__(self, is_linear: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_linear = is_linear
# as the weight of the linear module is transposed, we compute
# the corresponding dimension indices for convenience
if is_linear:
self.dim_q = 0
self.dim_p = 1
else:
self.dim_q = 1
self.dim_p = 0
def validate(self, input, other, bias) -> bool:
# make sure the second tensor is a 2D tensor
assert input.dim() > 0 and other.dim() == 2
# make sure the bias matches the output's feature dimension
if self.is_linear:
assert bias is None or bias.shape[-1] == other.shape[0]
else:
assert bias is None or bias.shape[-1] == other.shape[1]
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle the case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
self.dim_q: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
self.dim_p: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
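
Splitting the contracted dimension p yields per-device partial products that sum to the full result; the all-reduce over `mesh_dim_1` computes that sum. A NumPy sketch of the identity behind `SR = SS x SR` (illustration only, not part of this commit):

import numpy as np

a = np.random.rand(4, 6)  # [n, p] with p split across 2 devices
b = np.random.rand(6, 5)  # [p, q] with p split the same way
partials = [ai @ bi for ai, bi in zip(np.split(a, 2, axis=1), np.split(b, 2, axis=0))]
assert np.allclose(sum(partials), a @ b)  # all-reduce(partials) == full matmul
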
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict = {
"input": {
-1: [mesh_dim_0]
},
"other": {
self.dim_p: [mesh_dim_0],
self.dim_q: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
-1: [mesh_dim_1]
},
}
# the contracted dim p is split along mesh_dim_0, so the partial results must be all-reduced
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_0])
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
dim_partition_dict = {
"input": {
-1: [mesh_dim]
},
"other": {
self.dim_p: [mesh_dim]
},
"bias": {},
"output": {},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
dim_partition_dict = {
"input": {},
"other": {
self.dim_q: [mesh_dim]
},
"bias": {
-1: [mesh_dim]
},
"output": {
-1: [mesh_dim]
},
}
# no all-reduce is needed: the contracted dim is replicated, so every device computes its output shard in full
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict = {
"input": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {
self.dim_p: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {},
}
return IntermediateStrategy(name=name,
dim_partition_dict=dim_partition_dict,
all_reduce_axis=[mesh_dim_0, mesh_dim_1])
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict = {
"input": {},
"other": {
self.dim_q: [mesh_dim_0, mesh_dim_1]
},
"bias": {
-1: [mesh_dim_0, mesh_dim_1]
},
"output": {
-1: [mesh_dim_0, mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
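
The `dim_p`/`dim_q` bookkeeping in `__init__` can be checked against the [n, p] x [p, q] = [n, q] formulation from the docstring: `torch.matmul` takes `other` of shape [p, q], while `F.linear` takes a transposed weight of shape [q, p]. A quick shape check (illustration only, assuming PyTorch):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3)          # [n, p]
w_matmul = torch.randn(3, 5)   # [p, q]: dim_p = 0, dim_q = 1
w_linear = torch.randn(5, 3)   # [q, p]: dim_p = 1, dim_q = 0
assert torch.matmul(x, w_matmul).shape == (2, 5)
assert F.linear(x, w_linear).shape == (2, 5)
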
class BatchedMatMulStrategyGenerator(StrategyGenerator):
@@ -30,6 +259,15 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
super().__init__(*args, **kwargs)
self.is_torch_bmm = is_torch_bmm
def validate(self, input, other, bias) -> bool:
if self.is_torch_bmm:
# batch dims must match and the contracted dims must agree
assert input.dim() > 2 and input.dim() == other.dim()
assert input.shape[:-2] == other.shape[:-2]
assert input.shape[-1] == other.shape[-2]
# a bias, if present, must match the output's last dimension
assert bias is None or bias.shape[-1] == other.shape[-1]
else:
# TODO: validate that these inputs are broadcastable
pass
def split_one_batch_dim(self):
if 1 in self.device_mesh.mesh_shape:
mesh_dim = self.device_mesh.mesh_shape.index(1)

@@ -32,4 +32,14 @@ class StrategyGenerator(ABC):
@abstractmethod
def generate(self) -> List[IntermediateStrategy]:
"""
Generate all intermediate sharding strategies for the operation.
"""
pass
@abstractmethod
def validate(self, *args, **kwargs) -> bool:
"""
Validate that the operands have the expected shapes.
If this returns True, this generator can be used for the current operation.
"""
pass
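
For reference, a minimal concrete subclass could look like the following sketch (hypothetical, not part of this commit; the `IntermediateStrategy` fields follow their usage in the generators above):

class ReplicateOnlyStrategyGenerator(StrategyGenerator):
    """
    Toy generator that emits only the fully replicated strategy.
    """

    def validate(self, input) -> bool:
        assert input.dim() >= 1

    def generate(self) -> List[IntermediateStrategy]:
        # replicate the single operand and the output on every device
        return [IntermediateStrategy(name='R = R', dim_partition_dict={"input": {}, "output": {}})]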