from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

from torch.fx.node import Node

from colossalai.tensor.sharding_spec import ShardingSpec

from .constants import *

__all__ = ['ShardingStrategy', 'StrategiesVector']


@dataclass
class ShardingStrategy:
    '''
    ShardingStrategy is a structure that holds the sharding specs of the inputs and output of a node,
    together with the cost information used by the solver.

    Argument:
        name(str): expresses the sharding strategy as a string, such as 'S0S1 = S0R x RS1'.
        output_sharding_spec(ShardingSpec): ShardingSpec of the output of this node.
        compute_cost(float): Computation cost of completing this strategy. (default: 0)
        communication_cost(float): Communication cost of completing this strategy. (default: 0)
        memory_cost(float): Memory cost of the output of this node under this strategy. (default: 0)
        resharding_costs(Dict[Node, List[float]]): resharding_costs[node][j] is the cost of transforming the output of the
                                                   predecessor node, produced under the j-th strategy in its strategies_vector,
                                                   into the sharding spec required by this strategy. (default: None)
        input_shardings(List[ShardingSpec]): The ShardingSpecs of the input nodes. (default: None)
    '''
    name: str
    # TODO: the output of an fx node, such as torch.var_mean, could be a tuple, so we cannot simply assume it is a single tensor.
    output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
    compute_cost: float = 0.
    communication_cost: float = 0.
    memory_cost: float = 0.
    resharding_costs: Dict[Node, List[float]] = None
    # Sometimes an input could be a tuple of nodes, but most ops won't accept a tuple of nodes as input.
    # Therefore, we handle such inputs at the specific op (operator.getitem).
    input_shardings: List[ShardingSpec] = None
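

# A minimal usage sketch (illustrative only, not part of this module's public API): how a strategy
# generator might populate a ShardingStrategy for a matmul-like node. The helper name and its
# arguments are hypothetical; `input_spec`, `weight_spec` and `output_spec` are assumed to be
# ShardingSpec objects built elsewhere for a 2D device mesh.
def _example_matmul_strategy(input_spec: ShardingSpec, weight_spec: ShardingSpec,
                             output_spec: ShardingSpec) -> ShardingStrategy:
    # the name encodes the sharding pattern, e.g. an S0S1-sharded output from S0R x RS1 inputs
    return ShardingStrategy(name='S0S1 = S0R x RS1',
                            output_sharding_spec=output_spec,
                            compute_cost=0.,
                            communication_cost=0.,
                            memory_cost=0.,
                            resharding_costs={},
                            input_shardings=[input_spec, weight_spec])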


class StrategiesVector(list):
    '''
    Each node in the fx graph has a corresponding StrategiesVector, which stores all the possible
    sharding strategies of that node.

    Argument:
        node (Node): the node for which the list of sharding strategies is generated.
    '''

    def __init__(self, node: Node):
        super().__init__()
        self.node = node
        # fetch the input and output nodes of this node
        # TODO: handle placeholder input nodes
        self.predecessor_nodes = list(node._input_nodes.keys())
        self.successor_nodes = list(node.users.keys())

    def check_merge(self):
        merge_label = False
        if self.node.op == 'call_module':
            target = self.node.target
            root_module = self.node.graph.owning_module
            submod = root_module.get_submodule(target)
            submod_type = type(submod)
            # Merge element-wise module nodes into their source nodes: the output sharding spec of an
            # element-wise op is always the same as its input sharding spec.
            if submod_type in ELEMENTWISE_MODULE_OP:
                merge_label = True

        if self.node.op == 'call_function':
            # Element-wise functions can be merged because their output sharding spec is always the
            # same as their input sharding spec.
            if self.node.target in ELEMENTWISE_FUNC_OP:
                merge_label = True
            # Reshape ops can be merged because their output sharding spec is always fully replicated.
            if self.node.target in RESHAPE_FUNC_OP:
                merge_label = True

        return merge_label
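

# A minimal usage sketch (illustrative only, not part of this module's public API): how a
# strategy-generation pass might build one StrategiesVector per node of a traced torch.fx graph
# and use check_merge() to skip nodes that can be folded into their predecessors. The helper name
# and the surrounding pass logic are hypothetical; `graph` is assumed to be a torch.fx.Graph.
def _example_collect_strategies(graph) -> Dict[Node, StrategiesVector]:
    strategies_map = {}
    for node in graph.nodes:
        strategies_vector = StrategiesVector(node)
        # A real generator would append one ShardingStrategy per legal sharding of `node` here.
        if strategies_vector.check_merge():
            # Element-wise and reshape nodes follow their predecessors' sharding, so a solver
            # could merge them into their source nodes instead of solving them separately.
            continue
        strategies_map[node] = strategies_vector
    return strategies_map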