mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added all non-bcast matmul strategies (#1603)
parent
db98b695b2
commit
3abf98a633
|
@ -13,9 +13,238 @@ from typing import List
|
|||
# Names exported by `from <this module> import *`.
__all__ = ['DotHandler']
|
||||
|
||||
|
||||
class DotProductStrategyGenerator(StrategyGenerator):
    """
    DotProductStrategyGenerator is used to generate the sharding strategies for two 1D tensors in dot product computation.
    This is created for torch.matmul where both tensors are 1D. As torch.matmul does not take a bias argument,
    bias is not considered here.
    """

    def validate(self, input, other) -> bool:
        """
        Check that both operands are 1D tensors.

        Returns:
            bool: True when the operands fit this generator.

        Raises:
            AssertionError: if either operand is not 1D.
        """
        assert input.dim() == 1 and other.dim() == 1
        # the base-class contract treats a True return as "this generator is usable"
        return True

    def no_split(self):
        """Build the fully replicated strategy: no operand and no output is sharded."""
        name = 'R = R dot R'
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_one_dim(self, mesh_dim):
        """
        Build the strategy that shards dim 0 of both operands along ``mesh_dim``.

        The output partition is empty and ``mesh_dim`` is passed as the all-reduce axis.
        """
        # NOTE(review): the name advertises a sharded output (S{mesh_dim}) while the output
        # partition below is empty; the name is kept as-is because it may be looked up elsewhere.
        name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def generate(self) -> List[IntermediateStrategy]:
        """
        Generate all dot-product strategies.

        Returns:
            List[IntermediateStrategy]: the replicated strategy followed by the
            one-dim split strategies for mesh dims 0 and 1.
        """
        strategy_list = []

        # do not split dimensions for dot product
        # R = R dot R
        strategy_list.append(self.no_split())

        # split two tensors in the same dimensions
        # S = S dot S
        strategy_list.append(self.split_one_dim(0))
        strategy_list.append(self.split_one_dim(1))

        return strategy_list
|
||||
|
||||
|
||||
class MatVecStrategyGenerator(StrategyGenerator):
    """
    MatVecStrategyGenerator generates the sharding strategies for a product where the first
    operand has more than one dimension and the second operand is a 1D tensor.
    """

    def validate(self, input, other) -> bool:
        """
        Check that ``input`` has more than one dimension and ``other`` is 1D.

        Returns:
            bool: True when the operands fit this generator.

        Raises:
            AssertionError: if the operand dimensions do not match the expectation.
        """
        assert input.dim() > 1 and other.dim() == 1
        # the base-class contract treats a True return as "this generator is usable"
        return True

    def no_split(self):
        """Build the fully replicated strategy: no operand and no output is sharded."""
        name = "R = R x R"
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_input_batch(self, mesh_dim):
        """Build the strategy that shards dim 0 of ``input`` and of the output along ``mesh_dim``."""
        name = f'S{mesh_dim}R = S{mesh_dim}R x R'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def generate(self) -> List[IntermediateStrategy]:
        """
        Generate all matrix-vector strategies.

        Returns:
            List[IntermediateStrategy]: the replicated strategy followed by the
            input-batch split strategies for mesh dims 0 and 1.
        """
        strategy_list = []

        # no split
        strategy_list.append(self.no_split())

        # split the batch dim for the first tensor only
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        return strategy_list
|
||||
|
||||
|
||||
class MatMulStrategyGenerator(StrategyGenerator):
    """
    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
    a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.

    A matmul can be formulated as [n, p] x [p, q] = [n, q]

    Args:
        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
            This will incur extra transformation of the dim partitioning as the weight is transposed.
    """

    def __init__(self, is_linear: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_linear = is_linear

        # as the weight for the linear module is transposed, we compute
        # the corresponding dimension indices once for convenience
        if is_linear:
            self.dim_q = 0
            self.dim_p = 1
        else:
            self.dim_q = 1
            self.dim_p = 0

    def validate(self, input, other, bias) -> bool:
        """
        Check that ``other`` is 2D and that ``bias`` (if given) matches the q dimension of ``other``.

        Returns:
            bool: True when the operands fit this generator.

        Raises:
            AssertionError: if the operand shapes do not match the expectation.
        """
        # make sure the second tensor is a 2D tensor
        assert input.dim() > 0 and other.dim() == 2

        # make sure bias is of the same dimension as the q dim of the weight
        # (dim 0 for linear, dim 1 otherwise, which is exactly self.dim_q)
        assert bias is None or bias.shape[-1] == other.shape[self.dim_q]

        # the base-class contract treats a True return as "this generator is usable"
        return True

    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        """Handle SS = SR x RS: shard dim 0 of input along mesh_dim_0 and the q dim of other along mesh_dim_1."""
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """Handle SR = SS x SR: shard dim 0 of input along mesh_dim_0 and the contracted dim along mesh_dim_1."""
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0]
            },
        }
        # the contracted dim is sharded along mesh_dim_1, so that axis is all-reduced
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])

    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """Handle RS = RS x SS: shard the contracted dim along mesh_dim_0 and the q dim of other along mesh_dim_1."""
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim_0]
            },
            "other": {
                self.dim_p: [mesh_dim_0],
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_1]
            },
        }
        # the contracted dim is sharded along mesh_dim_0, so that axis is all-reduced,
        # matching the other contract-split strategies in this class
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_0])

    def recompute_split_both_contract(self, mesh_dim):
        """Handle RR = RS x SR: shard only the contracted dim of both operands along mesh_dim."""
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim]
            },
            "other": {
                self.dim_p: [mesh_dim]
            },
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_rhs_space_only(self, mesh_dim):
        """Handle RS = RR x RS: shard only the q dim of other (and bias/output last dim) along mesh_dim."""
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
        dim_partition_dict = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim]
            },
            "bias": {
                -1: [mesh_dim]
            },
            "output": {
                -1: [mesh_dim]
            },
        }
        # NOTE(review): no contracted dim is sharded here, yet an all-reduce axis is passed, unlike
        # split_rhs_2nd_dim_1d which has the same layout on a flattened mesh - confirm this is intended.
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """Shard dim 0 of input and output along both mesh dims; other and bias carry no partition."""
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """Shard the contracted dim of both operands along both mesh dims and all-reduce over both."""
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name,
                                    dim_partition_dict=dim_partition_dict,
                                    all_reduce_axis=[mesh_dim_0, mesh_dim_1])

    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """Shard the q dim of other (and bias/output last dim) along both mesh dims; input carries no partition."""
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

        dim_partition_dict = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
|
||||
class BatchedMatMulStrategyGenerator(StrategyGenerator):
|
||||
|
@ -30,6 +259,15 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
|
|||
super().__init__(*args, **kwargs)
|
||||
self.is_torch_bmm = is_torch_bmm
|
||||
|
||||
def validate(self, input, other, bias) -> bool:
    """
    Check the operands of a batched matmul.

    For the torch.bmm case the shapes are asserted; the broadcasting case is not validated yet.

    Returns:
        bool: True when the operands fit this generator.

    Raises:
        AssertionError: if a shape check fails in the torch.bmm case.
    """
    if self.is_torch_bmm:
        # NOTE(review): this requires both operands to have identical shapes, which only holds
        # for square batched matrices - confirm this restriction is intended.
        assert input.shape == other.shape
        assert input.dim() > 2
        # bias is optional; only check it when one is actually given, instead of
        # failing with AttributeError on None
        assert bias is None or other.shape[-1] == bias.shape[0]
    else:
        # TODO: validate these inputs are broadcastable
        pass

    # the base-class contract treats a True return as "this generator is usable"
    return True
|
||||
|
||||
def split_one_batch_dim(self):
|
||||
if 1 in self.device_mesh.mesh_shape:
|
||||
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
||||
|
|
|
@ -32,4 +32,14 @@ class StrategyGenerator(ABC):
|
|||
|
||||
@abstractmethod
def generate(self) -> List[IntermediateStrategy]:
    """
    Generate the candidate sharding strategies for the current operation.

    Returns:
        List[IntermediateStrategy]: the intermediate strategies produced by this generator.
    """
    pass
|
||||
|
||||
@abstractmethod
def validate(self, *args, **kwargs) -> bool:
    """
    Validate whether the operands are of the desired shape.

    Returns:
        bool: True means this generator can be used for the current operation.
    """
    pass
|
||||
|
|
Loading…
Reference in New Issue