diff --git a/colossalai/auto_parallel/solver/cost_graph.py b/colossalai/auto_parallel/solver/cost_graph.py
new file mode 100644
index 000000000..a67ac1c3f
--- /dev/null
+++ b/colossalai/auto_parallel/solver/cost_graph.py
@@ -0,0 +1,131 @@
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from typing import List
+from torch.fx.node import Node
+
+
+class CostGraph:
+    '''
+    A graph data structure to simplify the edge cost graph. It has two main functions:
+    1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
+       CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
+    2. To reduce the search space, we merge computationally-trivial operators, such as
+       element-wise operators, transpose, and reduction, into their following nodes. The merging information
+       is given by the StrategiesVector, depending on the types of the target node and its following nodes.
+
+    Argument:
+        leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
+        simplify(bool, optional): The generated cost graph will be simplified if it is true. (defaults to True)
+    '''
+
+    def __init__(self, leaf_strategies, simplify=True):
+        self.leaf_strategies = leaf_strategies
+        # stores the number of strategies in each node
+        self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
+        # extra_node_costs will store the extra costs introduced by merging nodes
+        self.extra_node_costs = {}
+        self.simplify = simplify
+        self._build_cost_graph()
+
+    def _build_cost_graph(self):
+        '''
+        This method generates edge_cost for every adjacent node pair. Additionally, the 'parents' and
+        'children' attributes are set on each node.
+        '''
+        self.edge_costs = {}
+        if self.simplify:
+            self.merge_pair = []
+        for strategies_vector in self.leaf_strategies:
+            # build edge_cost
+            dst_node = strategies_vector.node
+            for src_node in strategies_vector.predecessor_nodes:
+                node_pair = (src_node, dst_node)
+                src_index = strategies_vector.predecessor_nodes.index(src_node)
+                edge_cost = {}
+                for i in range(len(strategies_vector)):
+                    for j in range(len(src_node.strategies_vector)):
+                        edge_cost[(i, j)] = strategies_vector[i].resharding_costs[src_index][j]
+                self.edge_costs[node_pair] = edge_cost
+            # add parents and children attributes to node
+            setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
+            setattr(dst_node, 'children', strategies_vector.successor_nodes)
+
+            if self.simplify and strategies_vector.check_merge():
+                for following_node in strategies_vector.successor_nodes:
+                    self.merge_pair.append((dst_node, following_node))
+
+    def get_edge_cost(self, src_node, dst_node):
+        return self.edge_costs[(src_node, dst_node)]
+
+    def merge_node(self, src_node, dst_node):
+        '''
+        To merge src_node into dst_node, we need the following steps:
+
+        1. For each strategy in dst_node, we need to pick an appropriate strategy
+           of src_node to merge. This matters because the logical resharding costs
+           between the parent nodes of src_node and the merged node depend on which
+           src_node strategy is picked. For example, for the graph 0->1->2, after merging
+           node 1 into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(x, 0)],
+           where x is the strategy of node 1 picked for strategy 0 of node 2.
+
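+           Keys into an edge_cost dictionary follow the (dst strategy index, src strategy index)
+           convention established in _build_cost_graph.
+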
+        2. We need to accumulate the extra costs introduced by merging nodes. The extra cost
+           contains two parts: one is the resharding cost between the src_node strategy and the
+           dst_node strategy, and the other is the extra cost already accumulated on the src_node
+           strategy.
+
+        3. Build connections between the new node pairs, and remove src_node after all consumer
+           nodes are detached from it.
+
+        Argument:
+            src_node(Node): The node to be merged into dst_node.
+            dst_node(Node): The node to integrate src_node.
+        '''
+        src_node_index = dst_node.parents.index(src_node)
+        # build merge_map: for each dst strategy, pick the src strategy with the lowest resharding cost
+        merge_map = {}
+        for dst_strategy_index, strategy in enumerate(dst_node.strategies_vector):
+            resharding_costs = strategy.resharding_costs
+            resharding_cost_for_src = resharding_costs[src_node_index]
+            lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src))
+            merge_map[dst_strategy_index] = lowest_cost_index
+
+        # accumulate extra_node_costs for dst node
+        self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
+        for dst_strategy_index, strategy in enumerate(dst_node.strategies_vector):
+            target_strategy_index = merge_map[dst_strategy_index]
+            self.extra_node_costs[dst_node][dst_strategy_index] += strategy.resharding_costs[src_node_index][
+                target_strategy_index]
+            if src_node in self.extra_node_costs:
+                self.extra_node_costs[dst_node][dst_strategy_index] += self.extra_node_costs[src_node][
+                    target_strategy_index]
+
+        # connect dst node and parents of src node
+        dst_node.parents.remove(src_node)
+        src_node.children.remove(dst_node)
+        node_pairs_to_remove = [(src_node, dst_node)]
+        for parent_node in src_node.parents:
+            if parent_node not in dst_node.parents:
+                dst_node.parents.append(parent_node)
+            if dst_node not in parent_node.children:
+                parent_node.children.append(dst_node)
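+            # detach src node from its parents once it has no consumer left; the stale
+            # edges are only collected here and popped after the new edges are built,
+            # since building them still reads the old (parent, src) costs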
+            if len(src_node.children) == 0:
+                parent_node.children.remove(src_node)
+                node_pairs_to_remove.append((parent_node, src_node))
+
+        # add new node pairs to cost graph
+        for parent_node in src_node.parents:
+            new_node_pair = (parent_node, dst_node)
+            old_node_pair = (parent_node, src_node)
+            if new_node_pair in self.edge_costs:
+                continue
+            edge_cost = {}
+            for i in range(self.node_lens[dst_node]):
+                for j in range(self.node_lens[parent_node]):
+                    src_strategy_index = merge_map[i]
+                    edge_cost[(i, j)] = self.edge_costs[old_node_pair][(src_strategy_index, j)]
+            self.edge_costs[new_node_pair] = edge_cost
+
+        # remove the stale node pairs after the new edges are in place
+        for node_pair in node_pairs_to_remove:
+            if node_pair in self.edge_costs:
+                self.edge_costs.pop(node_pair)
+
+    def simplify_graph(self):
+        if not self.simplify:
+            return
+        for (src_node, dst_node) in self.merge_pair:
+            self.merge_node(src_node, dst_node)
diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py
index 63d17e6cb..1cacc9324 100644
--- a/colossalai/auto_parallel/solver/operator_handler.py
+++ b/colossalai/auto_parallel/solver/operator_handler.py
@@ -84,6 +84,7 @@ class OperatorHandler(ABC):
         for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
             resharding_costs[input_node] = []
             for strategy in input_node.strategies_vector:
-                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, input_spec)
+                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
+                    strategy.output_sharding_spec, input_spec)
                 resharding_costs[input_node].append(resharding_cost)
-        return resharding_cost
+        return resharding_costs
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index b6eb2e220..870d7e8dd 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -47,7 +47,7 @@ class StrategiesVector(list):
         self.node = node
         # fetch its input and output nodes
         self.predecessor_nodes = list(node._input_nodes.keys())
-        self.successor_ndoes = list(node.users.keys())
+        self.successor_nodes = list(node.users.keys())
 
     def check_merge(self):
         pass
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
index f15621477..00c434e03 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
@@ -15,7 +15,7 @@ def torch_matmul(input, other, *, out=None):
         shape = (input.size(0), other.size(1))
     elif d1 == 1 and d2 == 2:
         shape = (other.size(1),)
-    elif d1 == 2 and d1 == 1:
+    elif d1 == 2 and d2 == 1:
         shape = (input.size(0),)
     else:
         max_length = max(input.dim(), other.dim())
diff --git a/tests/test_auto_parallel/test_conv_handler.py b/tests/test_auto_parallel/test_conv_handler.py
index 13ec9a16f..3cda3bd80 100644
--- a/tests/test_auto_parallel/test_conv_handler.py
+++ b/tests/test_auto_parallel/test_conv_handler.py
@@ -70,7 +70,9 @@ def test_conv_handler():
         sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                      entire_shape=entire_shape,
                                      sharding_sequence=sharding_sequence)
-        strategies_vector_for_input.append(sharding_spec)
+        strategy_name = str(sharding_spec.sharding_sequence)
+        sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
+        strategies_vector_for_input.append(sharding_strategy)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
 
     # generate conv strategy
diff --git a/tests/test_auto_parallel/test_dot_handler.py b/tests/test_auto_parallel/test_dot_handler.py
index f85546b15..4cc41178d 100644
--- a/tests/test_auto_parallel/test_dot_handler.py
+++ b/tests/test_auto_parallel/test_dot_handler.py
@@ -69,7 +69,9 @@ def test_dot_handler():
         sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                      entire_shape=entire_shape,
                                      sharding_sequence=sharding_sequence)
-        strategies_vector_for_input.append(sharding_spec)
+        strategy_name = str(sharding_spec.sharding_sequence)
+        sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
+        strategies_vector_for_input.append(sharding_strategy)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
 
     # generate dot strategy
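To make the edge-cost layout and the merge step easier to review, here is a minimal, self-contained sketch of the idea. It is plain Python with made-up node names and cost numbers (nothing below calls the ColossalAI API; `resharding_costs` is a hand-written toy table), showing how `_build_cost_graph` linearizes the quadratic resharding costs into a dictionary keyed by strategy-index pairs, and how `merge_node`'s `merge_map` picks the cheapest src strategy for each dst strategy:

```python
# Toy illustration only: node "A" (2 strategies) feeds node "B" (3 strategies).
# resharding_costs[b][a] is the assumed cost of resharding A's output from
# A's strategy `a` into the layout expected by B's strategy `b`.
resharding_costs = [
    [0.0, 4.0],    # B strategy 0 vs. A strategies 0 / 1
    [2.0, 0.0],    # B strategy 1
    [1.0, 3.0],    # B strategy 2
]

# _build_cost_graph flattens the 2D table into edge_costs[(A, B)], keyed by
# (dst strategy index, src strategy index) tuples.
edge_cost = {(i, j): resharding_costs[i][j] for i in range(3) for j in range(2)}
assert edge_cost[(1, 0)] == 2.0

# merge_node's merge_map: for each dst strategy, keep the src strategy whose
# resharding cost is lowest; that cost is accumulated into extra_node_costs.
merge_map = {i: row.index(min(row)) for i, row in enumerate(resharding_costs)}
assert merge_map == {0: 0, 1: 1, 2: 0}

# In merge_node(src, dst), an analogous merge_map (built from the src-to-dst
# costs) redirects the edges of src's parents:
#   new_edge[(dst_strategy, parent_strategy)] = old_edge[(merge_map[dst_strategy], parent_strategy)]
```

This is the linearization the solver needs: each (dst, src) strategy pair maps to a single scalar, so the quadratic cost term becomes a flat lookup.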