From c27e701cb2e1b19d5d58de4e5cb10edc388bd916 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 29 Sep 2022 10:43:25 +0800
Subject: [PATCH] [autoparallel] remove no strategy nodes (#1652)

* [autoparallel] remove no strategy nodes

* fix none object iteration issue
---
 colossalai/auto_parallel/solver/cost_graph.py |  19 +++-
 .../auto_parallel/solver/sharding_strategy.py |   2 +
 colossalai/auto_parallel/solver/solver.py     |   9 +-
 .../solver/strategies_constructor.py          |  90 ++++++++++++-------
 4 files changed, 84 insertions(+), 36 deletions(-)

diff --git a/colossalai/auto_parallel/solver/cost_graph.py b/colossalai/auto_parallel/solver/cost_graph.py
index a4ec6c485..e491e79fb 100644
--- a/colossalai/auto_parallel/solver/cost_graph.py
+++ b/colossalai/auto_parallel/solver/cost_graph.py
@@ -1,6 +1,7 @@
 from typing import List
 import math
 from torch.fx.node import Node
+from colossalai.auto_parallel.solver.constants import INFINITY_COST
 
 
 class CostGraph:
@@ -19,6 +20,7 @@ class CostGraph:
 
     def __init__(self, leaf_strategies, simplify=True):
         self.leaf_strategies = leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
         # stores number of strategies in each node
         self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
         # extra_node_costs will store the extra costs introduced by merging nodes
@@ -27,6 +29,15 @@ class CostGraph:
         self.simplify = simplify
         self._build_cost_graph()
 
+    def _remove_invalid_node(self, node, attr_name):
+        remove_list = []
+        target_node_list = getattr(node, attr_name, [])
+        for target_node in target_node_list:
+            if target_node not in self.nodes:
+                remove_list.append(target_node)
+        for element in remove_list:
+            target_node_list.remove(element)
+
     def _build_cost_graph(self):
         '''
         This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
@@ -39,6 +50,8 @@ class CostGraph:
             # build edge_cost
             dst_node = strategies_vector.node
             for src_node in strategies_vector.predecessor_nodes:
+                if src_node not in self.nodes:
+                    continue
                 node_pair = (src_node, dst_node)
                 # src_index = strategies_vector.predecessor_nodes.index(src_node)
                 edge_cost = {}
@@ -49,6 +62,8 @@ class CostGraph:
             # add parents and children attribute to node
             setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
             setattr(dst_node, 'children', strategies_vector.successor_nodes)
+            self._remove_invalid_node(dst_node, 'parents')
+            self._remove_invalid_node(dst_node, 'children')
 
             if self.simplify and strategies_vector.check_merge():
                 for followed_node in strategies_vector.predecessor_nodes:
@@ -83,11 +98,11 @@ class CostGraph:
         # build merge_map
         merge_map = {}
         for src_index, strategy in enumerate(src_node.strategies_vector):
-            min_cost = math.inf
+            min_cost = INFINITY_COST
             lowest_cost_index = -1
             for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
                 resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
-                if resharding_cost < min_cost:
+                if resharding_cost <= min_cost:
                     min_cost = resharding_cost
                     lowest_cost_index = dst_index
             merge_map[src_index] = lowest_cost_index
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index e73a7281e..b81c25ffd 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -182,6 +182,8 @@ class StrategiesVector(list):
         # fetch its input and output nodes
         # TODO: placeholder input nodes
         self.predecessor_nodes = list(node._input_nodes.keys())
+        if self.node.op == 'output':
+            self.predecessor_nodes = list(node._input_nodes.keys())[:1]
         self.successor_nodes = list(node.users.keys())
 
     def check_merge(self):
diff --git a/colossalai/auto_parallel/solver/solver.py b/colossalai/auto_parallel/solver/solver.py
index 50f5c696f..8ca756c5e 100644
--- a/colossalai/auto_parallel/solver/solver.py
+++ b/colossalai/auto_parallel/solver/solver.py
@@ -45,8 +45,8 @@ class Solver:
         self.strategies_constructor = strategies_constructor
         self.cost_graph = cost_graph
         self.graph_analyser = graph_analyser
-        self.nodes = list(self.graph.nodes)
         self.leaf_strategies = self.strategies_constructor.leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
         self.strategy_map = self.strategies_constructor.strategy_map
         self.memory_budget = memory_budget
         self.solution_numbers = solution_numbers
@@ -67,7 +67,7 @@ class Solver:
         Therefore, the index of those strategies are copied from the previous node.
         This method is used to recover the strategy index of those merged node.
         '''
-        for node_index, node in enumerate(self.graph.nodes):
+        for node_index, node in enumerate(self.nodes):
             if node.strategies_vector.check_merge():
                 # the merged node has only one input, and its strategies follow the input sharding strategy
                 input_strategies_vector = node.args[0].strategies_vector
@@ -297,7 +297,8 @@ class Solver:
             num_edges += 1
             e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
             assert len(e[idx]) == len(r[idx])
-
+        for element in s:
+            assert len(element) > 0
         # 2. Set initial value ######################################
         # set a initial value for warm start
         #
@@ -317,12 +318,14 @@
         ###################################################################
         obj = 0
         for i in range(node_nums):
+            assert len(s[i]) == len(c[i])
            obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
 
         #############################################
         # computing the edge cost(resharding cost) #
         #############################################
         for i in range(len(E)):
+            assert len(e[i]) == len(r[i])
             obj += lpDot(e[i], r[i])
 
         prob += obj
diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py
index 6eb843eba..fe0adc0a4 100644
--- a/colossalai/auto_parallel/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/solver/strategies_constructor.py
@@ -214,6 +214,21 @@ class StrategiesConstructor:
                     linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
                     linear_handler.register_strategy()
 
+                # where function
+                elif target == torch.where:
+                    if input_nodes_len == 1:
+                        # both of x and y are scalar
+                        pass
+
+                    elif input_nodes_len == 2:
+                        # one of x or y is type of scalar
+                        pass
+
+                    else:
+                        # general case
+                        where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
+                        where_handler.register_strategy()
+
                 # reshape function
                 elif target in RESHAPE_FUNC_OP:
                     # use ReshapeHandler to create sharding strategies for rehsape node
@@ -222,9 +237,8 @@ class StrategiesConstructor:
 
                 # element-wise function
                 elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
-                    if isinstance(node._meta_data, torch.Tensor):
-                        unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
-                        unary_elementwise_handler.register_strategy()
+                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
+                    unary_elementwise_handler.register_strategy()
 
                 # bcast op
                 elif target in BCAST_FUNC_OP:
@@ -291,32 +305,34 @@ class StrategiesConstructor:
                 elif target == operator.getitem:
                     index = node.args[1]
                     input_tensor_node = strategies_vector.predecessor_nodes[0]
-                    if isinstance(input_tensor_node, torch.Tensor):
-                        for strategy in input_tensor_node.strategies_vector:
+                    for strategy in input_tensor_node.strategies_vector:
+                        if isinstance(strategy.output_sharding_spec, ShardingSpec):
+                            input_sharding_spec = strategy.output_sharding_spec
+                        else:
                             input_sharding_spec = strategy.output_sharding_spec[index]
-                            assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
-                            dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
-                            entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
-                            output_sharding_spec = ShardingSpec(self.device_mesh,
-                                                                entire_shape_output,
-                                                                dim_partition_dict=dim_partition_dict_for_output)
-                            # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
-                            compute_cost = 0
-                            memory_cost = 0
-                            resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                         [input_sharding_spec])
-                            # to prevent the resharding happening, set their resharding cost to inf.
-                            resharding_costs[input_tensor_node] = [
-                                cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
-                            ]
-                            sharding_strategy = ShardingStrategy(
-                                name,
-                                output_sharding_spec,
-                                compute_cost=compute_cost,
-                                memory_cost=memory_cost,
-                                resharding_costs=resharding_costs,
-                                input_shardings=[input_tensor_node.output_sharding_spec])
-                            strategies_vector.append(sharding_strategy)
+                        assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
+                        dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
+                        entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
+                        output_sharding_spec = ShardingSpec(self.device_mesh,
+                                                            entire_shape_output,
+                                                            dim_partition_dict=dim_partition_dict_for_output)
+                        # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
+                        compute_cost = 0
+                        memory_cost = 0
+                        resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                                     [input_sharding_spec],
+                                                                     index=index)
+                        # to prevent the resharding happening, set their resharding cost to inf.
+                        resharding_costs[input_tensor_node] = [
+                            cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
+                        ]
+                        sharding_strategy = ShardingStrategy(name,
+                                                             output_sharding_spec,
+                                                             compute_cost=compute_cost,
+                                                             memory_cost=memory_cost,
+                                                             resharding_costs=resharding_costs,
+                                                             input_shardings=[strategy.output_sharding_spec])
+                        strategies_vector.append(sharding_strategy)
 
                 # torch.arange function
                 elif target == torch.arange:
@@ -334,8 +350,7 @@ class StrategiesConstructor:
                     strategies_vector.append(sharding_strategy)
 
                 # op list to be processed to support gpt2
-                elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
-                                torch.nn.functional.softmax, torch.pow, torch.tanh):
+                elif target in (builtins.getattr, operator.le, torch.addmm):
                     pass
                 # other function
                 else:
@@ -344,7 +359,7 @@ class StrategiesConstructor:
             # call_method node
             if node.op == 'call_method':
                 method = getattr(node.args[0]._meta_data.__class__, node.target)
-                if method in (torch.Tensor.size, torch.Tensor.contiguous):
+                if method in (torch.Tensor.size,):
                     pass
                 elif method in ELEMENTWISE_METHOD_OP:
                     unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
@@ -400,6 +415,18 @@ class StrategiesConstructor:
             self.strategy_map[node] = strategies_vector
 
+        # remove no strategy nodes
+        remove_list = []
+        for strategies_vector in self.leaf_strategies:
+            if len(strategies_vector) == 0:
+                remove_list.append(strategies_vector.node)
+        for node in remove_list:
+            if node.strategies_vector in self.leaf_strategies:
+                self.leaf_strategies.remove(node.strategies_vector)
+            if node in self.strategy_map:
+                self.strategy_map.pop(node)
+
+
 class StrategiesConstructor_V2:
     """
     StrategiesConstructor is used to construct the parallelization plan for the model execution.
@@ -482,3 +509,4 @@ class StrategiesConstructor_V2:
             setattr(node, 'strategies_vector', strategies_vector)
             self.leaf_strategies.append(strategies_vector)
             self.strategy_map[node] = strategies_vector
+
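Note on the change: the substance of this patch is a "drop nodes that produced zero sharding strategies" pass. StrategiesConstructor prunes such nodes from leaf_strategies and strategy_map once all strategies are registered, and CostGraph defensively filters predecessor/successor lists (and skips edge-cost construction) against the surviving node set, so the solver never builds variables for a node without strategies. Below is a minimal, self-contained sketch of that pattern; FakeNode and the helper names are hypothetical stand-ins, not the real torch.fx or ColossalAI classes.

# Sketch of the "remove no-strategy nodes" pass (hypothetical types).

class FakeNode:
    def __init__(self, name, num_strategies):
        self.name = name
        # stands in for node.strategies_vector filled by the constructor
        self.strategies_vector = list(range(num_strategies))
        self.parents = []
        self.children = []


def prune_no_strategy_nodes(leaf_strategies, strategy_map):
    # collect nodes whose strategies_vector stayed empty, then drop them
    removed = [node for node, sv in list(strategy_map.items()) if len(sv) == 0]
    for node in removed:
        strategy_map.pop(node, None)
    # keep only strategy vectors that still contain at least one strategy
    leaf_strategies[:] = [sv for sv in leaf_strategies if len(sv) > 0]
    return removed


def filter_edges(node, valid_nodes):
    # analogous to CostGraph._remove_invalid_node: keep edges to surviving nodes only
    node.parents = [p for p in node.parents if p in valid_nodes]
    node.children = [c for c in node.children if c in valid_nodes]


if __name__ == "__main__":
    a, b, c = FakeNode("a", 2), FakeNode("b", 0), FakeNode("c", 3)
    a.children, c.parents = [b, c], [a, b]
    leaf_strategies = [a.strategies_vector, b.strategies_vector, c.strategies_vector]
    strategy_map = {a: a.strategies_vector, b: b.strategies_vector, c: c.strategies_vector}

    removed = prune_no_strategy_nodes(leaf_strategies, strategy_map)
    valid = list(strategy_map.keys())
    for node in valid:
        filter_edges(node, valid)

    print([n.name for n in removed])      # ['b']
    print([n.name for n in a.children])   # ['c']

In the patch itself, the same filtering also happens while edge costs are built (the `if src_node not in self.nodes: continue` guard), so no edge variable is created for a pruned predecessor.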