From 44c866a3e3af2c0020458d08e63ce5e5a303d448 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 7 Sep 2022 11:18:19 +0800 Subject: [PATCH] [autoparallel] change the merge node logic (#1533) --- colossalai/auto_parallel/solver/cost_graph.py | 96 ++++++++++++------- colossalai/tensor/shape_consistency.py | 3 +- tests/test_auto_parallel/test_cost_graph.py | 15 +-- 3 files changed, 71 insertions(+), 43 deletions(-) diff --git a/colossalai/auto_parallel/solver/cost_graph.py b/colossalai/auto_parallel/solver/cost_graph.py index 94691397d..220ab54a3 100644 --- a/colossalai/auto_parallel/solver/cost_graph.py +++ b/colossalai/auto_parallel/solver/cost_graph.py @@ -1,5 +1,6 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from typing import List +import math from torch.fx.node import Node @@ -23,6 +24,7 @@ class CostGraph: self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies} # extra_node_costs will store the extra costs introduced by merging nodes self.extra_node_costs = {} + self.following_dict = {} self.simplify = simplify self._build_cost_graph() @@ -50,15 +52,15 @@ class CostGraph: setattr(dst_node, 'children', strategies_vector.successor_nodes) if self.simplify and strategies_vector.check_merge(): - for following_node in strategies_vector.successor_nodes: - self.merge_pair.append((dst_node, following_node)) + for followed_node in strategies_vector.predecessor_nodes: + self.merge_pair.append((followed_node, dst_node)) def get_edge_cost(self, src_node, dst_node): return self.edge_costs[(src_node, dst_node)] def merge_node(self, src_node, dst_node): ''' - To merge src_node into dst_node, we need to do it in following steps: + To merge dst_node into src_node, we need to do it in following steps: 1. For each strategy in dst_node, we need to pick an appropriate strategy of src_node to merge, it is important because the logical resharding costs @@ -81,52 +83,76 @@ class CostGraph: src_node_index = dst_node.parents.index(src_node) # build merge_map merge_map = {} - for dst_strate_index, strategy in enumerate(dst_node.strategies_vector): - resharding_costs = strategy.resharding_costs - resharding_cost_for_src = resharding_costs[src_node] - lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src)) - merge_map[dst_strate_index] = lowest_cost_index + for src_index, strategy in enumerate(src_node.strategies_vector): + min_cost = math.inf + lowest_cost_index = -1 + for dst_index, dst_strategy in enumerate(dst_node.strategies_vector): + resharding_cost = dst_strategy.resharding_costs[src_node][src_index] + if resharding_cost < min_cost: + min_cost = resharding_cost + lowest_cost_index = dst_index + merge_map[src_index] = lowest_cost_index - # extra_node_cost for dst node - self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])] - for dst_strate_index, strategy in enumerate(dst_node.strategies_vector): - target_strate_index = merge_map[dst_strate_index] - self.extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node][ - target_strate_index] - if src_node in self.extra_node_costs: - self.extra_node_costs[dst_node][dst_strate_index] += self.extra_node_costs[src_node][ - target_strate_index] + # extra_node_cost for src node + self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node] + for src_index, strategy in enumerate(src_node.strategies_vector): + target_strate_index = merge_map[src_index] + target_strategy = dst_node.strategies_vector[target_strate_index] + self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index] + if dst_node in self.extra_node_costs: + self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index] # add new node pair to cost graph - for parent_node in src_node.parents: - new_node_pair = (parent_node, dst_node) - old_node_pair = (parent_node, src_node) + for child_node in dst_node.children: + new_node_pair = (src_node, child_node) + old_node_pair = (dst_node, child_node) if new_node_pair in self.edge_costs: continue edge_cost = {} - for i in range(self.node_lens[dst_node]): - for j in range(self.node_lens[parent_node]): - src_strate_index = merge_map[i] - edge_cost[(j, i)] = self.edge_costs[old_node_pair][(j, src_strate_index)] - self.edge_costs[new_node_pair] = edge_cost + for i in range(self.node_lens[src_node]): + for j in range(self.node_lens[child_node]): + dst_strate_index = merge_map[i] + # dst_strategy = dst_node.strategies_vector[dst_strate_index] + edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)] + if new_node_pair not in self.edge_costs: + self.edge_costs[new_node_pair] = edge_cost + else: + # we should accumulate the resharding costs if args of child node contain + # both src node and dst node. + for index_pair, resharding_cost in self.edge_costs[new_node_pair]: + self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair] - # connect dst node and parents of src node + # connect src node and children of dst node dst_node.parents.remove(src_node) src_node.children.remove(dst_node) self.edge_costs.pop((src_node, dst_node)) - for parent_node in src_node.parents: - if parent_node not in dst_node.parents: - dst_node.parents.append(parent_node) - if dst_node not in parent_node.children: - parent_node.children.append(dst_node) - # remove src node from cost graph when src node has no consumer. - if len(src_node.children) == 0: - parent_node.children.remove(src_node) - node_pair = (parent_node, src_node) + for child_node in dst_node.children: + if child_node not in src_node.children: + src_node.children.append(child_node) + if src_node not in child_node.parents: + child_node.parents.append(src_node) + # remove dst node from cost graph when dst node has no producer. + if len(dst_node.parents) == 0: + child_node.parents.remove(dst_node) + node_pair = (dst_node, child_node) self.edge_costs.pop(node_pair) + if len(dst_node.parents) == 0: + self.following_dict[dst_node] = src_node + dst_node.children = [] + + def _reindexing_src(self, src): + if src not in self.following_dict: + return src + return self._reindexing_src(self.following_dict[src]) def simplify_graph(self): if not self.simplify: return + self.merge_pair.reverse() for (src_node, dst_node) in self.merge_pair: self.merge_node(src_node, dst_node) + self.merge_pair.reverse() + reindexing_following_dict = {} + for dst, src in self.following_dict.items(): + reindexing_following_dict[dst] = self._reindexing_src(src) + self.following_dict = reindexing_following_dict diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index 5e7ec68f3..d411918e1 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -82,7 +82,8 @@ class CommSpec: if self.comm_pattern == CollectiveCommPattern.ALLREDUCE: return self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) if self.comm_pattern == CollectiveCommPattern.SHARD: - return 0 + # give a tiny cost to shard + return 10 raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.") def covert_spec_to_action(self, tensor): diff --git a/tests/test_auto_parallel/test_cost_graph.py b/tests/test_auto_parallel/test_cost_graph.py index bb3e05087..30e3ece3b 100644 --- a/tests/test_auto_parallel/test_cost_graph.py +++ b/tests/test_auto_parallel/test_cost_graph.py @@ -58,11 +58,11 @@ def test_cost_graph(): strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options) strategies_constructor.build_strategies_and_cost() - # (x, mul): {(0, 0): 0} - # (mul, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0} - # (conv1, truediv): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): 0, (8, 0): inf, (9, 0): 0, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): inf, (5, 1): 0, (6, 1): 0, (7, 1): inf, (8, 1): 0, (9, 1): inf, (10, 1): 0, (11, 1): 0, (12, 1): 0, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): 0, (11, 2): 0, (12, 2): 0, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): 0, (11, 3): 0, (12, 3): 0, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): 0, (11, 4): 0, (12, 4): 0, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): 0, (11, 5): 0, (12, 5): 0, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): 0, (11, 7): 0, (12, 7): 0, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0, (9, 8): inf, (10, 8): 0, (11, 8): 0, (12, 8): 0, (13, 8): inf, (14, 8): 0} - # (truediv, relu): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): inf, (5, 0): 0, (6, 0): 0, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): 0, (5, 1): inf, (6, 1): 0, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): 0, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): 0, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): 0, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): 0, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): 0, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0} - # (relu, output): {(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002} + # (x, mul):{(0, 0): 0} + # (mul, conv1):{(0, 0): 65547.1, (0, 1): 65547.1, (0, 2): 65547.1, (0, 3): 65547.1, (0, 4): 131105.30000000002, (0, 5): 131105.30000000002, (0, 6): 65547.1, (0, 7): 65547.1, (0, 8): 65547.1, (0, 9): 65547.1, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 131105.30000000002, (0, 14): 131105.30000000002} + # (conv1, truediv):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): inf, (11, 0): inf, (12, 0): inf, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): 0, (6, 1): inf, (7, 1): inf, (8, 1): inf, (9, 1): inf, (10, 1): inf, (11, 1): inf, (12, 1): inf, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): inf, (11, 2): inf, (12, 2): inf, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): inf, (11, 3): inf, (12, 3): inf, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): inf, (11, 4): inf, (12, 4): inf, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): inf, (11, 5): inf, (12, 5): inf, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): inf, (11, 7): inf, (12, 7): inf, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): inf, (9, 8): inf, (10, 8): inf, (11, 8): inf, (12, 8): inf, (13, 8): inf, (14, 8): 0} + # (truediv, relu):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): inf, (6, 1): inf, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): inf, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): inf, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): 0} + # (relu, output):{(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002} cost_graph = CostGraph(strategies_constructor.leaf_strategies) # construct all node pairs @@ -83,9 +83,10 @@ def test_cost_graph(): # add (x, conv) and (conv, output) into check node pairs merged_node_pairs.append((node_list[0], node_list[2])) merged_node_pairs.append((node_list[2], node_list[-1])) - # (x, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0} - # (conv1, output): {(0, 0): inf, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf} + # (conv1, output):{(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 246019.30000000002, (5, 0): 246019.30000000002, (6, 0): 123009.1, (7, 0): 123009.1, (8, 0): 123009.1, (9, 0): 123009.1, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): 246019.30000000002, (14, 0): 246019.30000000002} + # (x, conv1):{(0, 0): 65547.1, (0, 1): 65547.1, (0, 2): 65547.1, (0, 3): 65547.1, (0, 4): 131105.30000000002, (0, 5): 131105.30000000002, (0, 6): 65547.1, (0, 7): 65547.1, (0, 8): 65547.1, (0, 9): 65547.1, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 131105.30000000002, (0, 14): 131105.30000000002} cost_graph.simplify_graph() + for node_pair in all_node_pairs: if node_pair in merged_node_pairs: assert node_pair in cost_graph.edge_costs