mirror of https://github.com/hpcaitech/ColossalAI
55 lines
2.4 KiB
Python
55 lines
2.4 KiB
Python
|
class ShardingStrategy:
    '''
    A record describing one candidate sharding strategy of a node: the sharding
    specs of its inputs and output, together with the cost information used by
    the solver.

    Argument:
        name(str): the strategy written as a string, such as 'S0S1 = S0R x RS1'.
        output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
        compute_cost(float): computation cost to complete this strategy. (default to 0)
        communication_cost(float): communication cost to complete this strategy. (default to 0)
        memory_cost(float): memory cost of the output node using this strategy. (default to 0)
        resharding_costs(Dict[int, List[float]]): resharding_costs[i][j] is the cost of
            transforming the i-th argument in the output node argument list, when that
            argument uses the j-th strategy in its strategies_vector, to the sharding
            spec wanted in this strategy. (default to None)
        input_shardings(List[ShardingSpec]): the ShardingSpecs of the input nodes.
            (default to None)
    '''

    def __init__(
        self,
        name,
        output_sharding_spec,
        compute_cost=0,
        communication_cost=0,
        memory_cost=0,
        resharding_costs=None,
        input_shardings=None,
    ):
        # Identity of the strategy and the layout it produces.
        self.name = name
        self.output_sharding_spec = output_sharding_spec

        # Scalar costs, stored exactly as supplied by the caller.
        self.compute_cost = compute_cost
        self.communication_cost = communication_cost
        self.memory_cost = memory_cost

        # Per-argument resharding costs and input layouts; both stay None
        # when the caller does not provide them.
        self.resharding_costs = resharding_costs
        self.input_shardings = input_shardings
|
||
|
|
||
|
|
||
|
class StrategiesVector:
    '''
    Each node in fx graph will have a corresponding StrategiesVector, to store all the
    possible strategies of the node.

    Argument:
        node(Node): node to build corresponding strategies_vector.
        in_nodes(List[Node]): input nodes in the argument list of the node.
        following_nodes(List[Node]): the nodes that take the target node as their
            argument. (default to None)
        strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies
            of the node. When omitted (None), a new empty list is created for this
            instance. (default to None)
    '''

    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
        self.node = node
        self.in_nodes = in_nodes
        self.following_nodes = following_nodes
        # A literal `[]` default would be evaluated once and shared by every
        # instance, so strategies appended to one vector would show up in all
        # others. Build a fresh list per instance instead; an explicitly passed
        # list is stored as-is (same object), matching the previous behavior.
        self.strategies = [] if strategies is None else strategies

    def check_merge(self):
        '''Placeholder: currently performs no work and returns None.'''
        pass
|