mirror of https://github.com/hpcaitech/ColossalAI
735 lines
31 KiB
Python
735 lines
31 KiB
Python
import operator
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
|
from .operator_handler import OperatorHandler
|
|
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
|
|
from functools import reduce
|
|
from colossalai.auto_parallel.solver._utils import exception_handler
|
|
from enum import Enum
|
|
from .strategy_generator import StrategyGenerator, IntermediateStrategy
|
|
from typing import List
|
|
|
|
# Only DotHandler is exported by `from <module> import *`; the strategy generator classes are not listed here.
__all__ = ['DotHandler']
|
|
|
|
|
|
class DotProductStrategyGenerator(StrategyGenerator):
    """
    Strategy generator for a dot product between two 1D tensors.

    It targets ``torch.matmul`` when both operands are 1D tensors. ``torch.matmul`` has no bias
    argument, so the generated partition dicts carry no bias entry.
    """

    def validate(self, input, other):
        # both operands have to be vectors
        assert input.dim() == 1 and other.dim() == 1

    def no_split(self):
        """Keep both vectors and the result fully replicated (R = R dot R)."""
        return IntermediateStrategy(name='R = R dot R',
                                    dim_partition_dict={
                                        "input": {},
                                        "other": {},
                                        "output": {}
                                    })

    def split_one_dim(self, mesh_dim):
        """Shard both vectors along ``mesh_dim``; the partial results are all-reduced over that mesh axis."""
        shard = f'S{mesh_dim}'
        return IntermediateStrategy(name=f'{shard} = {shard} dot {shard}',
                                    dim_partition_dict={
                                        "input": {
                                            0: [mesh_dim]
                                        },
                                        "other": {
                                            0: [mesh_dim]
                                        },
                                        "output": {}
                                    },
                                    all_reduce_axis=[mesh_dim])

    def generate(self) -> List[IntermediateStrategy]:
        # R = R dot R first, then S = S dot S on each of the two mesh axes
        return [self.no_split()] + [self.split_one_dim(mesh_dim) for mesh_dim in (0, 1)]
|
|
|
|
|
|
class MatVecStrategyGenerator(StrategyGenerator):
    """
    Strategy generator for a matrix (or batch of matrices) multiplied by a 1D vector.

    Only the leading dimension of the first operand is ever sharded; the vector stays replicated.
    """

    def validate(self, input, other) -> bool:
        # first operand has at least 2 dims, second operand is a vector
        assert input.dim() > 1 and other.dim() == 1

    def no_split(self):
        """Keep every tensor fully replicated (R = R x R)."""
        partitions = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name="R = R x R", dim_partition_dict=partitions)

    def split_input_batch(self, mesh_dim):
        """Shard dim 0 of the first operand and of the output along ``mesh_dim``."""
        partitions = {
            "input": {
                0: [mesh_dim]
            },
            "other": {},
            "output": {
                0: [mesh_dim]
            },
        }
        return IntermediateStrategy(name=f'S{mesh_dim}R = S{mesh_dim}R x R', dim_partition_dict=partitions)

    def generate(self) -> List[IntermediateStrategy]:
        # replicated strategy, then the leading dim split on mesh axis 0 and on mesh axis 1
        strategies = [self.no_split()]
        strategies.extend(self.split_input_batch(mesh_dim) for mesh_dim in (0, 1))
        return strategies
|
|
|
|
|
|
class MatMulStrategyGenerator(StrategyGenerator):
    """
    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
    a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.

    A matmul can be formulated as [n, p] x [p, q] = [n, q]

    Args:
        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
            This will incur extra transformation of the dim partitioning as the weight is transposed.
    """

    def __init__(self, is_linear: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_linear = is_linear

        # as the weight for the linear module is transposed ([q, p] instead of [p, q]), we compute
        # the corresponding dimension index of p (contracted) and q (space) for convenience
        if is_linear:
            self.dim_q = 0
            self.dim_p = 1
        else:
            self.dim_q = 1
            self.dim_p = 0

    def validate(self, input, other, bias) -> bool:
        # make sure the second tensor is a 2D tensor
        assert input.dim() > 0 and other.dim() == 2

        # make sure bias matches the q dimension of the second tensor
        if self.is_linear:
            assert bias is None or bias.shape[-1] == other.shape[0]
        else:
            assert bias is None or bias.shape[-1] == other.shape[1]

    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        """SS = SR x RS: shard n on ``mesh_dim_0`` and q on ``mesh_dim_1``; no contraction is split."""
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """SR = SS x SR: shard n on ``mesh_dim_0`` and the contracted p on ``mesh_dim_1``."""
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0]
            },
        }
        # p is split across mesh_dim_1, so the partial products are reduced over that axis
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])

    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """RS = RS x SS: shard the contracted p on ``mesh_dim_0`` and q on ``mesh_dim_1``."""
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim_0]
            },
            "other": {
                self.dim_p: [mesh_dim_0],
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_1]
            },
        }
        # p is split across mesh_dim_0, so (as in the other *_both_contract strategies) the
        # partial products must be all-reduced over the contracted mesh axis
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_0])

    def recompute_split_both_contract(self, mesh_dim):
        """RR = RS x SR: shard only the contracted p on ``mesh_dim``; the output is replicated."""
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim]
            },
            "other": {
                self.dim_p: [mesh_dim]
            },
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_rhs_space_only(self, mesh_dim):
        """RS = RR x RS: shard only q on ``mesh_dim``."""
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
        dim_partition_dict = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim]
            },
            "bias": {
                -1: [mesh_dim]
            },
            "output": {
                -1: [mesh_dim]
            },
        }
        # NOTE(review): the forward pass needs no reduction here; this all-reduce axis presumably
        # accounts for the input gradient in the backward pass — confirm against the cost model.
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """S01R = S01R x RR: shard n across both mesh axes (flattened)."""
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """RR = RS01 x S01R: shard the contracted p across both mesh axes (flattened)."""
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name,
                                    dim_partition_dict=dim_partition_dict,
                                    all_reduce_axis=[mesh_dim_0, mesh_dim_1])

    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """RS01 = RR x RS01: shard q across both mesh axes (flattened)."""
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
|
|
|
|
|
class BatchedMatMulStrategyGenerator(StrategyGenerator):
    """
    Generate sharding strategies for the batched matrix multiplication.

    A batched matrix multiplication can be viewed as
    [b, i, k] x [b, k, j] -> [b, i, j]

    Args:
        is_torch_bmm (bool): whether the node is ``torch.bmm``; only then are the operand shapes validated.
    """

    def __init__(self, is_torch_bmm: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_torch_bmm = is_torch_bmm

    def validate(self, input, other, bias) -> bool:
        if self.is_torch_bmm:
            # [b, i, k] x [b, k, j]: same rank (> 2), identical batch dims, matching contracted dim k
            assert input.dim() > 2 and input.dim() == other.dim()
            assert input.shape[:-2] == other.shape[:-2]
            assert input.shape[-1] == other.shape[-2]
            # torch.bmm itself takes no bias, so only check a bias when one is supplied
            assert bias is None or bias.shape[-1] == other.shape[-1]
        else:
            # TODO: validate these inputs are broadcastable
            pass

    def split_one_batch_dim(self, mesh_dim=None):
        """
        Sb = Sb x Sb: shard only the batch dim of both operands and the output along one mesh axis.

        Args:
            mesh_dim (int, optional): mesh axis to shard along. When omitted, the first mesh axis of
                size 1 is chosen (as the original auto-selecting variant did) and ``None`` is
                returned if the mesh has no such axis.
        """
        if mesh_dim is None:
            mesh_shape = self.device_mesh.mesh_shape
            if 1 not in mesh_shape:
                return None
            # NOTE(review): this selects the size-1 axis, mirroring the previous behaviour — confirm
            # whether the non-trivial axis was intended instead.
            mesh_dim = mesh_shape.index(1)

        name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
        """Sb01 = Sb01 x Sb01: shard the batch dim across both mesh axes (flattened)."""
        name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            }
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
        """SbSi = SbSi x Sb: shard b on ``mesh_dim_0`` and i of the first operand on ``mesh_dim_1``."""
        name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0]
            },
            "bias": {},
            "output": {
                # mesh axes are always given as a list, like every other entry
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            }
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
        """SbSj = Sb x SbSj: shard b on ``mesh_dim_0`` and j of the second operand on ``mesh_dim_1``."""
        name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            }
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
        """SbR = SbSk x SbSk: shard b on ``mesh_dim_0`` and the contracted k on ``mesh_dim_1``."""
        name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                # k is contracted away, so the output is only sharded on the batch dim
                0: [mesh_dim_0]
            }
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])

    def generate(self) -> List[IntermediateStrategy]:
        strategy_list = []

        # split only the batch dimension
        # Sb = Sb x Sb
        # can be None when the device mesh has no size-1 axis
        strategy = self.split_one_batch_dim()
        if strategy:
            strategy_list.append(strategy)

        # split batch dim of two inputs and the i dim of the first tensor
        # SbSi = SbSi x Sb
        strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
        strategy_list.append(self.split_batch_dim_lhs_space(1, 0))

        # split batch dim of two inputs and the j of the second tensor
        # SbSj = Sb x SbSj
        strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
        strategy_list.append(self.split_batch_dim_rhs_space(1, 0))

        # split batch dim of two inputs and the k dim of two inputs
        # Sb = SbSk x SbSk, need to all-reduce by k dim
        strategy_list.append(self.split_batch_dim_both_contract(0, 1))
        strategy_list.append(self.split_batch_dim_both_contract(1, 0))

        # split two batch dim
        strategy_list.append(self.split_two_batch_dim(0, 1))
        strategy_list.append(self.split_two_batch_dim(1, 0))

        return strategy_list
|
|
|
|
|
|
class DotHandler(OperatorHandler):
    """
    A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.

    The matmul is viewed as [n, p] x [p, q] = [n, q]. The linear weight is stored transposed,
    i.e. with shape [q, p], so in the weight partition dicts below dim 0 is the space dim q and
    dim 1 is the contracted dim p.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, input_shape, weight_shape):
        """Return 2 * prod(input_shape) * q, where q = weight_shape[0] (the transposed weight's output dim)."""
        # TODO: consider bias addition
        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
        return compute_cost

    def _append_linear_strategy(self, name, dim_partition_dict_for_input, dim_partition_dict_for_weight,
                                dim_partition_dict_for_output, communication_cost_fn):
        """
        Build the sharding specs and costs for one strategy and append it to ``self.strategies_vector``.

        Args:
            name (str): strategy name.
            dim_partition_dict_for_input / _weight / _output (dict): tensor dim -> list of mesh axes.
            communication_cost_fn (callable): maps the activation memory cost of the strategy to its
                communication cost.
        """
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        total_memory_cost, activation_memory_cost, _ = self._generate_memory_cost(
            dim_partition_dict_for_output, dim_partition_dict_for_weight)

        # compute the communication cost of this strategy
        communication_cost = communication_cost_fn(activation_memory_cost)

        sharding_strategy = ShardingStrategy(name,
                                             output_sharding_spec=sharding_spec_for_output,
                                             compute_cost=compute_cost,
                                             communication_cost=communication_cost,
                                             memory_cost=total_memory_cost,
                                             resharding_costs=resharding_costs,
                                             input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategy)

    @exception_handler
    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        # handle case SS = SR x RS
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
        self._append_linear_strategy(
            name,
            dim_partition_dict_for_input={0: [mesh_dim_0]},
            # linear layer weight is transposed during init, so q is weight dim 0
            dim_partition_dict_for_weight={0: [mesh_dim_1]},
            # the output is sharded on both of its dims
            dim_partition_dict_for_output={0: [mesh_dim_0], 1: [mesh_dim_1]},
            # no all-reduce required for this case
            communication_cost_fn=lambda activation_memory_cost: 0)

    @exception_handler
    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case SR = SS x SR
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
        self._append_linear_strategy(
            name,
            dim_partition_dict_for_input={0: [mesh_dim_0], 1: [mesh_dim_1]},
            # since weight of the linear layer is transposed, the contracted dim p is dim 1
            dim_partition_dict_for_weight={1: [mesh_dim_1]},
            dim_partition_dict_for_output={0: [mesh_dim_0]},
            # p is split across mesh_dim_1, so partial results are all-reduced over it
            communication_cost_fn=lambda activation_memory_cost: self.device_mesh.all_reduce_cost(
                activation_memory_cost, mesh_dim_1))

    @exception_handler
    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case RS = RS x SS
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
        self._append_linear_strategy(
            name,
            dim_partition_dict_for_input={1: [mesh_dim_0]},
            # transposed weight [q, p]: q (dim 0) on mesh_dim_1, contracted p (dim 1) on mesh_dim_0
            dim_partition_dict_for_weight={0: [mesh_dim_1], 1: [mesh_dim_0]},
            dim_partition_dict_for_output={1: [mesh_dim_1]},
            # p is split across mesh_dim_0, so partial results are all-reduced over the contracted axis
            communication_cost_fn=lambda activation_memory_cost: self.device_mesh.all_reduce_cost(
                activation_memory_cost, mesh_dim_0))

    @exception_handler
    def recompute_split_both_contract(self, mesh_dim):
        # handle the case RR = RS x SR
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
        self._append_linear_strategy(
            name,
            dim_partition_dict_for_input={1: [mesh_dim]},
            # the contracted dim p is dim 1 of the transposed weight
            dim_partition_dict_for_weight={1: [mesh_dim]},
            dim_partition_dict_for_output={},
            communication_cost_fn=lambda activation_memory_cost: self.device_mesh.all_reduce_cost(
                activation_memory_cost, mesh_dim))

    @exception_handler
    def split_rhs_space_only(self, mesh_dim):
        # handle the case RS = RR x RS
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
        self._append_linear_strategy(
            name,
            dim_partition_dict_for_input={},
            # q is dim 0 of the transposed weight
            dim_partition_dict_for_weight={0: [mesh_dim]},
            dim_partition_dict_for_output={1: [mesh_dim]},
            # NOTE(review): the forward pass needs no reduction here; this cost presumably accounts
            # for all-reducing the input gradient in the backward pass — confirm.
            communication_cost_fn=lambda activation_memory_cost: self.device_mesh.all_reduce_cost(
                activation_memory_cost, mesh_dim))

    @exception_handler
    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        # handle the case S01R = S01R x RR
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
        self._append_linear_strategy(name,
                                     dim_partition_dict_for_input={0: [mesh_dim_0, mesh_dim_1]},
                                     dim_partition_dict_for_weight={},
                                     dim_partition_dict_for_output={0: [mesh_dim_0, mesh_dim_1]},
                                     communication_cost_fn=lambda activation_memory_cost: 0)

    @exception_handler
    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        # handle the case RR = RS01 x S01R
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
        self._append_linear_strategy(
            name,
            dim_partition_dict_for_input={1: [mesh_dim_0, mesh_dim_1]},
            # the contracted dim p is dim 1 of the transposed weight
            dim_partition_dict_for_weight={1: [mesh_dim_0, mesh_dim_1]},
            dim_partition_dict_for_output={},
            # p is split across both mesh axes, so reduce over the flattened mesh
            communication_cost_fn=lambda activation_memory_cost: self.device_mesh.flatten_device_mesh.
            all_reduce_cost(activation_memory_cost, 0))

    @exception_handler
    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        # handle the case RS01 = RR x RS01
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
        self._append_linear_strategy(
            name,
            dim_partition_dict_for_input={},
            # q is dim 0 of the transposed weight
            dim_partition_dict_for_weight={0: [mesh_dim_0, mesh_dim_1]},
            dim_partition_dict_for_output={1: [mesh_dim_0, mesh_dim_1]},
            communication_cost_fn=lambda activation_memory_cost: 0)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate every possible strategies for a linear node, and record all strategies into the strategies_vector.

        Output:
            StrategiesVector: ``self.strategies_vector`` with every strategy that could be built appended.
        '''
        # SS = SR x RS
        self.split_lhs_space_rhs_space(0, 1)
        self.split_lhs_space_rhs_space(1, 0)

        # SR = SS x SR
        self.split_lhs_space_both_contract(0, 1)
        self.split_lhs_space_both_contract(1, 0)

        # RS = RS x SS
        self.split_rhs_space_both_contract(0, 1)
        self.split_rhs_space_both_contract(1, 0)

        # RR= RS x SR
        self.recompute_split_both_contract(0)
        self.recompute_split_both_contract(1)

        # RS = RR x RS
        self.split_rhs_space_only(0)
        self.split_rhs_space_only(1)

        # S01R = S01R x RR
        self.split_lhs_1st_dim_1d(0, 1)

        # RR = RS01 x S01R
        self.split_lhs_2nd_dim_1d(0, 1)

        # RS01 = RR x RS01
        self.split_rhs_2nd_dim_1d(0, 1)

        return self.strategies_vector
|