diff --git a/colossalai/auto_parallel/solver/_utils.py b/colossalai/auto_parallel/solver/_utils.py
index 9cdc984cb..378a14d03 100644
--- a/colossalai/auto_parallel/solver/_utils.py
+++ b/colossalai/auto_parallel/solver/_utils.py
@@ -95,7 +95,8 @@ def exception_handler(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         try:
-            func(*args, **kwargs)
+            rst = func(*args, **kwargs)
+            return rst
         except AssertionError as e:
             warnings.warn(f'{e}')
diff --git a/colossalai/auto_parallel/solver/cost_graph.py b/colossalai/auto_parallel/solver/cost_graph.py
index e491e79fb..a5f418be4 100644
--- a/colossalai/auto_parallel/solver/cost_graph.py
+++ b/colossalai/auto_parallel/solver/cost_graph.py
@@ -170,3 +170,188 @@ class CostGraph:
         for dst, src in self.following_dict.items():
             reindexing_following_dict[dst] = self._reindexing_src(src)
         self.following_dict = reindexing_following_dict
+
+
+class CostGraph_V2:
+    '''
+    A graph data structure to simplify the edge cost graph. It has two main functions:
+    1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
+    CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
+    2. To reduce the searching space, we merge computationally-trivial operators, such as
+    element-wise operators, transpose, and reduction, into their following nodes. The merging information will
+    be given by the StrategiesVector depending on the type of the target node and its following nodes.
+
+    Argument:
+        leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
+        simplify(bool, optional): The generated cost graph will be simplified if it is true. (defaults to True)
+    '''
+
+    def __init__(self, leaf_strategies, simplify=True, forward_only=False):
+        self.leaf_strategies = leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
+        # stores the number of strategies of each node
+        self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
+        # extra_node_costs will store the extra costs introduced by merging nodes
+        self.extra_node_costs = {}
+        self.following_dict = {}
+        self.simplify = simplify
+        self.forward_only = forward_only
+        self._build_cost_graph()
+
+    def _remove_invalid_node(self, node, attr_name):
+        remove_list = []
+        target_node_list = getattr(node, attr_name, [])
+        for target_node in target_node_list:
+            if target_node not in self.nodes:
+                remove_list.append(target_node)
+        for element in remove_list:
+            target_node_list.remove(element)
+
+    def _build_cost_graph(self):
+        '''
+        This method will generate edge_cost for each adjacent node pair. Additionally, the 'parents' and 'children'
+        attributes will be set on the node.
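+
+        For example, if a predecessor has 2 strategies and the current node has 3, the resulting
+        entry is a dict of 6 items keyed by (src_strategy_index, dst_strategy_index), e.g.
+        edge_costs[(src_node, dst_node)] = {(0, 0): c_00, (0, 1): c_01, ..., (1, 2): c_12},
+        where each value is the resharding cost (the forward cost if forward_only is set,
+        otherwise the total cost).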
+        '''
+        self.edge_costs = {}
+        if self.simplify:
+            self.merge_pair = []
+        for strategies_vector in self.leaf_strategies:
+            # build edge_cost
+            dst_node = strategies_vector.node
+            for src_node in strategies_vector.predecessor_nodes:
+                if src_node not in self.nodes:
+                    continue
+                node_pair = (src_node, dst_node)
+                # src_index = strategies_vector.predecessor_nodes.index(src_node)
+                edge_cost = {}
+                for i in range(len(strategies_vector)):
+                    for j in range(len(src_node.strategies_vector)):
+                        if strategies_vector[i].resharding_costs is None:
+                            print(strategies_vector.node.name)
+                            assert False
+                        resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
+                        if self.forward_only:
+                            edge_cost[(j, i)] = resharding_cost_item.fwd
+                        else:
+                            edge_cost[(j, i)] = resharding_cost_item.total
+                self.edge_costs[node_pair] = edge_cost
+            # add parents and children attributes to the node
+            setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
+            setattr(dst_node, 'children', strategies_vector.successor_nodes)
+            self._remove_invalid_node(dst_node, 'parents')
+            self._remove_invalid_node(dst_node, 'children')
+
+            if self.simplify and strategies_vector.check_merge():
+                for followed_node in strategies_vector.predecessor_nodes:
+                    self.merge_pair.append((followed_node, dst_node))
+
+    def get_edge_cost(self, src_node, dst_node):
+        return self.edge_costs[(src_node, dst_node)]
+
+    def merge_node(self, src_node, dst_node):
+        '''
+        To merge dst_node into src_node, we need to do it in the following steps:
+
+        1. For each strategy of src_node, we need to pick an appropriate strategy
+        of dst_node to merge into it. This choice matters because the resharding costs
+        between the merged node and the children of dst_node depend on which dst_node
+        strategy is picked. For example, for the graph 0->1->2, after merging node 1
+        into node 0, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 1, node 2)][(x, 0)],
+        where x is the strategy of node 1 picked for strategy 0 of node 0.
+
+        2. We need to accumulate the extra costs introduced by merging nodes. These extra costs
+        contain two parts: one is the resharding cost between the src_node strategy and the picked
+        dst_node strategy, the other is the extra cost already accumulated on the picked dst_node strategy.
+
+        3. Build connections between the new node pairs, and detach dst_node after all of its consumer
+        nodes have been connected to src_node.
+
+        Argument:
+            src_node(Node): The node into which dst_node will be merged.
+            dst_node(Node): The node to be merged into src_node.
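+
+        For illustration: suppose src_node has strategies [s0, s1] and dst_node has strategies
+        [t0, t1, t2]. For each si we pick the tj with the lowest resharding cost (merge_map[i] = j),
+        accumulate that cost (plus any extra cost already attached to tj) into
+        extra_node_costs[src_node][i], and rewrite every edge (dst_node, child) so that
+        edge_costs[(src_node, child)][(i, k)] = edge_costs[(dst_node, child)][(merge_map[i], k)].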
+ ''' + src_node_index = dst_node.parents.index(src_node) + # build merge_map + merge_map = {} + for src_index, strategy in enumerate(src_node.strategies_vector): + min_cost = INFINITY_COST + lowest_cost_index = -1 + for dst_index, dst_strategy in enumerate(dst_node.strategies_vector): + resharding_cost_item = dst_strategy.resharding_costs[src_node][src_index] + if self.forward_only: + resharding_cost = resharding_cost_item.fwd + else: + resharding_cost = resharding_cost_item.total + if resharding_cost <= min_cost: + min_cost = resharding_cost + lowest_cost_index = dst_index + merge_map[src_index] = lowest_cost_index + + # extra_node_cost for src node + self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node] + for src_index, strategy in enumerate(src_node.strategies_vector): + target_strate_index = merge_map[src_index] + target_strategy = dst_node.strategies_vector[target_strate_index] + resharding_cost_item = target_strategy.resharding_costs[src_node][src_index] + if self.forward_only: + resharding_cost_to_add = resharding_cost_item.fwd + else: + resharding_cost_to_add = resharding_cost_item.total + self.extra_node_costs[src_node][src_index] += resharding_cost_to_add + if dst_node in self.extra_node_costs: + self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index] + + # add new node pair to cost graph + for child_node in dst_node.children: + new_node_pair = (src_node, child_node) + old_node_pair = (dst_node, child_node) + if new_node_pair in self.edge_costs: + continue + edge_cost = {} + for i in range(self.node_lens[src_node]): + for j in range(self.node_lens[child_node]): + dst_strate_index = merge_map[i] + # dst_strategy = dst_node.strategies_vector[dst_strate_index] + edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)] + if new_node_pair not in self.edge_costs: + self.edge_costs[new_node_pair] = edge_cost + else: + # we should accumulate the resharding costs if args of child node contain + # both src node and dst node. + for index_pair, resharding_cost in self.edge_costs[new_node_pair]: + self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair] + + # connect src node and children of dst node + dst_node.parents.remove(src_node) + src_node.children.remove(dst_node) + self.edge_costs.pop((src_node, dst_node)) + for child_node in dst_node.children: + if child_node not in src_node.children: + src_node.children.append(child_node) + if src_node not in child_node.parents: + child_node.parents.append(src_node) + # remove dst node from cost graph when dst node has no producer. 
+ if len(dst_node.parents) == 0: + child_node.parents.remove(dst_node) + node_pair = (dst_node, child_node) + self.edge_costs.pop(node_pair) + if len(dst_node.parents) == 0: + self.following_dict[dst_node] = src_node + dst_node.children = [] + + def _reindexing_src(self, src): + if src not in self.following_dict: + return src + return self._reindexing_src(self.following_dict[src]) + + def simplify_graph(self): + if not self.simplify: + return + self.merge_pair.reverse() + for (src_node, dst_node) in self.merge_pair: + self.merge_node(src_node, dst_node) + self.merge_pair.reverse() + reindexing_following_dict = {} + for dst, src in self.following_dict.items(): + reindexing_following_dict[dst] = self._reindexing_src(src) + self.following_dict = reindexing_following_dict diff --git a/colossalai/auto_parallel/solver/op_handler/__init__.py b/colossalai/auto_parallel/solver/op_handler/__init__.py index ab0cf58f5..9c7e2e595 100644 --- a/colossalai/auto_parallel/solver/op_handler/__init__.py +++ b/colossalai/auto_parallel/solver/op_handler/__init__.py @@ -9,9 +9,16 @@ from .unary_elementwise_handler import UnaryElementwiseHandler from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler from .layer_norm_handler_v2 import LayerNormModuleHandler from .batch_norm_handler_v2 import BatchNormModuleHandler +from .conv_handler_v2 import ConvModuleHandler, ConvFunctionHandler +from .where_handler_v2 import WhereHandler +from .unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2 +from .reshape_handler_v2 import ReshapeHandler_V2 +from .placeholder_handler import PlacehodlerHandler +from .output_handler import OuputHandler __all__ = [ 'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler', 'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler', - 'LayerNormModuleHandler', 'BatchNormModuleHandler' + 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', + 'UnaryElementwiseHandler_V2', 'ReshapeHandler_V2', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler' ] diff --git a/colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py index 69a96c8ed..7085c3d2b 100644 --- a/colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py +++ b/colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py @@ -40,7 +40,7 @@ class ConvModuleHandler(ModuleHandler): mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if self.named_parameters['bias'] is not None: + if "bias" in self.named_parameters: physical_bias_operand = OperationData(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']) @@ -53,7 +53,6 @@ class ConvModuleHandler(ModuleHandler): """ for op_data, sharding_spec in strategy.input_sharding_specs.items(): if op_data.name == "weight": - assert op_data.logical_shape != op_data.data.shape dim_partition_dict = sharding_spec.dim_partition_dict # switch first and second dim of the conv module weight diff --git a/colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py index 977a4c94a..76ce1a766 100644 --- a/colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py +++ b/colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py @@ -6,12 +6,13 @@ from typing import List, Dict from .registry import operator_registry import operator 
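(Illustration of the merge bookkeeping in CostGraph_V2 above, assuming a chain conv -> relu -> flatten in which relu and flatten are both detected as mergeable: merge_node first records following_dict[flatten] = relu and then following_dict[relu] = conv; simplify_graph reindexes both entries to conv via _reindexing_src, so the solver later creates independent decision variables only for conv and reuses them for the two merged nodes.)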
-__all__ = ['ReshapeHandler']
+__all__ = ['ReshapeHandler_V2']


 @operator_registry.register(torch.reshape)
+@operator_registry.register(torch.flatten)
 @operator_registry.register(torch.Tensor.permute)
-class ReshapeHandler(NodeHandler):
+class ReshapeHandler_V2(NodeHandler):
     """
     A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
     """
diff --git a/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py
index 7ba71b00b..75b59f827 100644
--- a/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py
+++ b/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py
@@ -6,12 +6,12 @@ from typing import List, Dict
 from .registry import operator_registry
 import operator

-__all__ = ['UnaryElementwiseHandler']
+__all__ = ['UnaryElementwiseHandler_V2']


 @operator_registry.register(torch.abs)
 @operator_registry.register(torch.nn.ReLU)
-class UnaryElementwiseHandler(NodeHandler):
+class UnaryElementwiseHandler_V2(NodeHandler):
     """
     A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
     """
diff --git a/colossalai/auto_parallel/solver/solver.py b/colossalai/auto_parallel/solver/solver.py
index 6cd1e26c8..97674c088 100644
--- a/colossalai/auto_parallel/solver/solver.py
+++ b/colossalai/auto_parallel/solver/solver.py
@@ -465,3 +465,464 @@ class Solver:
             ret_list.append(ret)

         return ret_list
+
+
+class Solver_V2:
+
+    def __init__(self,
+                 graph: Graph,
+                 strategies_constructor: StrategiesConstructor,
+                 cost_graph: CostGraph,
+                 graph_analyser: GraphAnalyser,
+                 memory_budget: float = -1.0,
+                 solution_numbers: int = 1,
+                 forward_only: bool = False,
+                 memory_increasing_coefficient: float = 1.3):
+        '''
+        Solver class will integrate information provided by the components and use an ILP solver to find a possibly optimal strategy combination for the target computing graph.
+
+        Argument:
+            graph: The computing graph to be optimized.
+            strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
+            cost_graph: A graph data structure to simplify the edge cost graph.
+            graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
+            memory_budget: Memory constraint for the solution.
+            solution_numbers: If solution_numbers is larger than one, the solver will return a series of solutions based on different memory budgets.
+            memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budgets.
+        '''
+        self.graph = graph
+        self.strategies_constructor = strategies_constructor
+        self.cost_graph = cost_graph
+        self.graph_analyser = graph_analyser
+        self.leaf_strategies = self.strategies_constructor.leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
+        self.strategy_map = self.strategies_constructor.strategy_map
+        self.memory_budget = memory_budget
+        self.solution_numbers = solution_numbers
+        self.forward_only = forward_only
+        if self.solution_numbers > 1:
+            self.memory_increasing_coefficient = memory_increasing_coefficient
+        else:
+            self.memory_increasing_coefficient = 1
+        self.liveness_list = self.graph_analyser.liveness_analysis()
+        self.node_index_dict = self._generate_node_index_dict()
+        # The last solution vector of auto sharding.
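+        # Each entry is the index of the strategy chosen for the corresponding node, following the
+        # order of self.nodes.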
+ self.last_s_val = None + # The last objective value of the best ILP solution. + self.last_objective = None + + def _recover_merged_node_strategy(self): + ''' + During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. + Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged + node. + ''' + for node_index, node in enumerate(self.nodes): + if node.strategies_vector.check_merge(): + # the merged node has only one input, and its strategies follow the input sharding strategy + input_strategies_vector = node.args[0].strategies_vector + input_best_strategy_index = self.last_s_val[node_index - 1] + input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec + for strategy_index, strategy in enumerate(node.strategies_vector): + if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence: + self.last_s_val[node_index] = strategy_index + break + + def _generate_node_index_dict(self) -> Dict[Node, int]: + node_index_dict = {} + for index, strategies_vector in enumerate(self.leaf_strategies): + node_index_dict[strategies_vector.node] = index + return node_index_dict + + def _prepare_data_for_solver(self): + ''' + Extract information from components for solver. + ''' + node_nums = len(self.leaf_strategies) + memory_budget = self.memory_budget + + # prepare strategies_len + strategies_len = [] + for node in self.nodes: + strategies_len.append(self.cost_graph.node_lens[node]) + strategies_len = np.array(strategies_len) + + # prepare following_nodes + following_nodes = self.cost_graph.following_dict + index_following_nodes = {} + for src, target in following_nodes.items(): + src_index = self.node_index_dict[src] + target_index = self.node_index_dict[target] + index_following_nodes[src_index] = target_index + following_nodes = index_following_nodes + for index in range(node_nums): + if index not in following_nodes: + following_nodes[index] = -1 + + # prepare edge_pairs and resharding costs + edge_pairs = [] + resharding_costs = [] + for pairs, edge_cost in self.cost_graph.edge_costs.items(): + src_node = pairs[0] + dst_node = pairs[1] + src_node_index = self.node_index_dict[src_node] + dst_node_index = self.node_index_dict[dst_node] + edge_pairs.append(src_node_index) + edge_pairs.append(dst_node_index) + + for i in range(strategies_len[src_node_index]): + for j in range(strategies_len[dst_node_index]): + resharding_costs.append(edge_cost[(i, j)]) + edge_pairs = np.array(edge_pairs) + resharding_costs = np.array(resharding_costs) + + # prepare liveness_set + liveness_set = self.liveness_list + + # omit alias_set now + alias_set = None + alias_convert_costs = None + + # prepare compute_costs, communication_costs and memory_costs + compute_costs = [] + communication_costs = [] + memory_costs = [] + extra_node_costs = self.cost_graph.extra_node_costs + for strategies_vector in self.leaf_strategies: + node = strategies_vector.node + for index, strategy in enumerate(strategies_vector): + compute_cost_item = strategy.compute_cost + communication_cost_item = strategy.communication_cost + memory_cost_item = strategy.memory_cost + + if self.forward_only: + origin_communication_cost = communication_cost_item.fwd + compute_cost = compute_cost_item.fwd + memory_cost = memory_cost_item.fwd + else: + origin_communication_cost = communication_cost_item.total + compute_cost = compute_cost_item.total + 
memory_cost = memory_cost_item.total + + compute_costs.append(compute_cost) + # node in extra_node_costs means it has some extra communication + # cost from node merging, so we need to add those extra communication + # cost into + if node in extra_node_costs: + extra_node_cost = extra_node_costs[node][index] + communication_cost = origin_communication_cost + extra_node_cost + communication_costs.append(communication_cost) + else: + communication_costs.append(origin_communication_cost) + memory_costs.append(memory_cost) + # if isinstance(memory_cost, tuple): + # memory_costs.append(memory_cost[0]) + # else: + # memory_costs.append(memory_cost) + compute_costs = np.array(compute_costs) + communication_costs = np.array(communication_costs) + memory_costs = np.array(memory_costs) + + # omit initial value for nodes + s_init_np = None + + return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np + + def _call_solver_serialized_args(self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np=None): + """ + Call the solver with serialized arguments. + """ + + tic = time.time() + + for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]: + assert isinstance(x, np.ndarray) + assert len(strategies_len) == node_nums, "strategies_len" + + def get_non_zero_index(binary_vector): + """ + Get the index of non-zero item in a vector. + """ + ct = 0 + ret = None + for i, elem in enumerate(binary_vector): + if pulp.value(elem): + ret = i + ct += 1 + + assert ct == 1 + return ret + + # 0. Unpack flatten numpy arrays + s_follow = following_nodes + + E = edge_pairs.reshape((-1, 2)) # noqa + r = [] + pt = 0 + edge_set = set() + for (i, j) in E: + prod_length = strategies_len[i] * strategies_len[j] + + if (i, j) in edge_set: + raise ValueError(f"Duplicated edges: {(i, j)}") + + edge_set.add((i, j)) + r.append(resharding_costs[pt:pt + prod_length]) + pt += prod_length + assert pt == len(resharding_costs) + + ###################### + # omit alias set now # + ###################### + + # A = alias_set.reshape((-1, 2)) # noqa + # for (i, j) in A: + # prod_length = strategies_len[i] * strategies_len[j] + # v.append(alias_convert_costs[pt:pt + prod_length]) + # pt += prod_length + # assert pt == len(alias_convert_costs) + + # L = [] # noqa + # pt = node_nums + # for i in range(node_nums): + # length = liveness_set[i] + # L.append(liveness_set[pt:pt + length]) + # pt += length + # assert pt == len(liveness_set) + v = [] + pt = 0 + + c = [] + d = [] + m = [] + pt = 0 + for i in range(node_nums): + length = strategies_len[i] + c.append(compute_costs[pt:pt + length]) + d.append(communication_costs[pt:pt + length]) + m.append(memory_costs[pt:pt + length]) + pt += length + assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" + assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" + assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}" + + # 1. 
Create variables + + ############################# + # create variables for node # + ############################# + s = [] + num_nodes = 0 + reverse_follow_backpatch = [] + for i in range(node_nums): + if s_follow[i] < 0: + if strategies_len[i] == 1: + s.append([1]) + else: + num_nodes += 1 + s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) + else: + if s_follow[i] < len(s): + s.append(s[s_follow[i]]) + else: + s.append(None) + reverse_follow_backpatch.append(i) + + for i in reverse_follow_backpatch: + s[i] = s[s_follow[i]] + + ############################# + # create variables for edge # + ############################# + e = [] + num_edges = 0 + for (idx, (i, j)) in enumerate(E): + if len(s[i]) == 1: + e.append(s[j]) + elif len(s[j]) == 1: + e.append(s[i]) + else: + num_edges += 1 + e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) + assert len(e[idx]) == len(r[idx]) + for element in s: + assert len(element) > 0 + # 2. Set initial value + ###################################### + # set a initial value for warm start # + ###################################### + if s_init_np is not None: + s_init = s_init_np.reshape((-1, 3)) + for (idx, value, fix) in s_init: + for i in range(len(s[idx])): + s[idx][i].setInitialValue(i == value) + if fix: + s[idx][i].fixValue() + + # 3. Objective + prob = LpProblem("myProblem", LpMinimize) + ################################################################### + # computing the node cost(computing cost and communication cost) # + ################################################################### + obj = 0 + for i in range(node_nums): + assert len(s[i]) == len(c[i]) + assert len(s[i]) == len(d[i]) + + obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) + + ############################################# + # computing the edge cost(resharding cost) # + ############################################# + for i in range(len(E)): + assert len(e[i]) == len(r[i]) + obj += lpDot(e[i], r[i]) + + prob += obj + + # 4. Constraints + # (a). specified by `cat="Binary"` + + # (b) + ################################################# + # make sure each node only choose one strategy # + ################################################# + for i in range(node_nums): + if s_follow[i] < 0: + prob += lpSum(s[i]) == 1 + + # (c) + ################################################# + # compute memory consumption with liveness set # + ################################################# + if memory_budget > 0: + for liveness_stage in liveness_set: + mem = 0 + for live_variable in liveness_stage.unique_live_vars: + node_index = self.node_index_dict[live_variable.node] + mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) + prob += mem <= memory_budget + + # (d). 
specified by `cat="Binary"` + + for (idx, (i, j)) in enumerate(E): + if strategies_len[i] == 1 or strategies_len[j] == 1: + continue + + # (e) + prob += lpSum(e[idx]) == 1 + + # (f) + for row in range(len(s[i])): + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] + + # (g) + for col in range(len(s[j])): + R = len(s[i]) # noqa + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] + + # (h) + ###################### + # omit alias set now # + ###################### + + # alias_set = set() + # for (idx, (i, j)) in enumerate(A): + # R = len(s[i]) # noqa + # C = len(s[j]) # noqa + # if (i, j) in alias_set: + # raise ValueError(f"Duplicated edges: {(i, j)}") + + # alias_set.add((i, j)) + # alias_set.add((j, i)) + + # for row in range(len(s[i])): + # for col in range(len(s[j])): + # if v[idx][row * C + col] > 0.5: + # prob += s[i][row] + s[j][col] <= 1 + + verbose = True + + msg = verbose + time_limit = 600 + assert "COIN_CMD" in pulp.listSolvers( + onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") + + solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) + # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) + prob.solve(solver) + + status = prob.status + objective = pulp.value(prob.objective) + objective = float(objective) if objective is not None else -1.0 + if verbose: + print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" + f"Time: {time.time() - tic}") + print(f"#nodes: {num_nodes}, #edges: {num_edges}") + + if prob.status in [pulp.LpStatusInfeasible]: + raise RuntimeError("Cannot run the function under the given memory budget. " + "Please increase the memory budget.") + + # Get and check results + s_val = np.full((node_nums,), -1, dtype=np.int32) + for i in range(node_nums): + s_val[i] = get_non_zero_index(s[i]) + + e_val = np.full((len(E),), -1, dtype=np.int32) + for (idx, (i, j)) in enumerate(E): + e_val[idx] = get_non_zero_index(e[idx]) + i_spec_index = e_val[idx] // len(s[j]) + j_spec_index = e_val[idx] % len(s[j]) + assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" + assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" + if verbose and r[idx][e_val[idx]] > 0: + print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") + + self.last_s_val = list(s_val) + # self._recover_merged_node_strategy() + self.last_objective = objective + + if objective > INFINITY_COST: + warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") + + return self.last_s_val, e_val, self.last_objective, status + + def call_solver_serialized_args(self): + """ + Call the solver with serialized arguments and handle python errors. Additionally, + we could give a serious of solutions with different memory budget. 
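+
+        For example, with memory_budget = 16 GB, solution_numbers = 3 and
+        memory_increasing_coefficient = 1.3, the budgets tried are 16 GB, 20.8 GB and 27.04 GB.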
+ """ + if self.solution_numbers == 1: + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + + return ret + + origin_memory_budget = self.memory_budget + memory_budget_list = [ + origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers) + ] + ret_list = [] + for memory_budget in memory_budget_list: + self.memory_budget = memory_budget + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + ret_list.append(ret) + + return ret_list diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py index fe0adc0a4..0da540cae 100644 --- a/colossalai/auto_parallel/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/solver/strategies_constructor.py @@ -1,10 +1,13 @@ from torch.fx import Graph, Node from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler +from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy_V2 from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.auto_parallel.solver.op_handler.registry import operator_registry +from colossalai.auto_parallel.solver.op_handler.placeholder_handler import PlacehodlerHandler +from colossalai.auto_parallel.solver.op_handler.output_handler import OuputHandler from .options import SolverOptions from . import ShardingStrategy, StrategiesVector from .op_handler import * @@ -414,7 +417,6 @@ class StrategiesConstructor: self.leaf_strategies.append(strategies_vector) self.strategy_map[node] = strategies_vector - # remove no strategy nodes remove_list = [] for strategies_vector in self.leaf_strategies: @@ -456,6 +458,10 @@ class StrategiesConstructor_V2: name_checklist = [] remove_list = [] for strategy in strategies_vector: + if strategy is None: + print(strategies_vector.node.name) + print(strategies_vector) + assert False if strategy.name not in name_checklist: name_checklist.append(strategy.name) else: @@ -469,16 +475,32 @@ class StrategiesConstructor_V2: """ for node in self.nodes: strategies_vector = StrategiesVector(node) - # placeholder node if node.op == 'placeholder': - # TODO: implement placeholder node handler - pass + placeholder_handler = PlacehodlerHandler(node, self.device_mesh, strategies_vector) + placeholder_handler.register_strategy() # get_attr node - elif node.op == 'get_attr': - # TODO: implement getattr node handler - pass + if node.op == 'get_attr': + # Same as placeholder nodes, if solver_options.fast is True, we just let them in + # fully replicate status, then strategies of following node will be treated equally due + # to replicate status has no resharding cost to other status. At the same time, the searching + # space is smaller than enumerating all the possible sharding spec for the get_attr node. + # Otherwise, all the possible sharding spec for the get_attr node will be enumerated. 
+ if self.solver_options.fast: + # create sharding strategy for get_attr + name = 'Replica Attribute' + dim_partition_dict = {} + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost) + strategies_vector.append(sharding_strategy_attribute) + + # # get_attr node + # elif node.op == 'get_attr': + # # TODO: implement getattr node handler + # pass # call_module node elif node.op == 'call_module': @@ -502,11 +524,13 @@ class StrategiesConstructor_V2: # output node elif node.op == 'output': - # TODO: implement output node handler - pass + output_handler = OuputHandler(node, self.device_mesh, strategies_vector) + output_handler.register_strategy() + if len(strategies_vector) <= 0: + print(node.name) + assert len(strategies_vector) > 0 self.remove_duplicated_strategy(strategies_vector) setattr(node, 'strategies_vector', strategies_vector) self.leaf_strategies.append(strategies_vector) self.strategy_map[node] = strategies_vector - diff --git a/colossalai/auto_parallel/solver/strategy/__init__.py b/colossalai/auto_parallel/solver/strategy/__init__.py index e7ecbb58c..a71b0e03e 100644 --- a/colossalai/auto_parallel/solver/strategy/__init__.py +++ b/colossalai/auto_parallel/solver/strategy/__init__.py @@ -8,10 +8,14 @@ from .layer_norm_generator import LayerNormGenerator from .where_generator import WhereGenerator from .reshape_generator import ReshapeGenerator from .normal_pooling_generator import NormalPoolStrategyGenerator +from .placeholder_generator import PlaceholderGenerator +from .output_generator import OutputGenerator + __all__ = [ 'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', - 'TensorTupleStrategyGenerator', 'LayerNormGenerator', "WhereGenerator", 'ReshapeGenerator', 'NormalPoolStrategyGenerator' + 'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', + 'WhereGenerator', 'ReshapeGenerator', 'NormalPoolStrategyGenerator' ] diff --git a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py index a599aca66..58c76cc96 100644 --- a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py @@ -5,6 +5,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator_V2 from typing import List from .._utils import exception_handler +import warnings import copy @@ -100,6 +101,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost + @exception_handler def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' @@ -146,6 +148,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @exception_handler def split_input_batch(self, mesh_dim_0): name = f'S{mesh_dim_0}R = 
S{mesh_dim_0}R x RR' @@ -182,6 +185,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @exception_handler def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' @@ -228,6 +232,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @exception_handler def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' @@ -267,6 +272,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @exception_handler def split_input_in_channel_weight_in_channel(self, mesh_dim_0): name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' @@ -297,6 +303,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @exception_handler def split_weight_out_channel(self, mesh_dim_0): name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' @@ -329,6 +336,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @exception_handler def non_split(self): name = f'RR = RR x RR' @@ -347,6 +355,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}) + @exception_handler def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' @@ -384,6 +393,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @exception_handler def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' dim_partition_dict_mapping = { @@ -413,6 +423,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @exception_handler def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' dim_partition_dict_mapping = { @@ -482,10 +493,20 @@ class ConvStrategyGenerator(StrategyGenerator_V2): # RS01 = RR x RS01 strategies.append(self.split_1d_parallel_on_out_channel(0, 1)) + rm_list = [strategy for strategy in strategies if strategy is None] + for rm_element in rm_list: + strategies.remove(rm_element) + illegal_strategy_list = [] # update mete info on cost for strategy in strategies: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) + try: + self.update_communication_cost(strategy) + self.update_compute_cost(strategy) + self.update_memory_cost(strategy) + except AssertionError as e: + illegal_strategy_list.append(strategy) + warnings.warn(f'{e}') + for strategy in illegal_strategy_list: + strategies.remove(strategy) return strategies diff --git a/tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py index 
3b03c7e91..c0dd02722 100644 --- a/tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py +++ b/tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py @@ -5,8 +5,10 @@ from colossalai.fx import ColoTracer, ColoGraphModule from colossalai.auto_parallel.solver.op_handler.normal_pooling_handler import NormPoolingHandler from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh +import pytest +@pytest.mark.skip("for higher testing speed") def test_norm_pool_handler(): model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py index 758337ef0..8ae352778 100644 --- a/tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py +++ b/tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from colossalai.fx import ColoTracer, ColoGraphModule from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler -from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler +from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler_V2 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh @@ -48,9 +48,9 @@ def test_reshape_handler(): strategies_vector=conv_strategies_vector) conv_handler.register_strategy(compute_resharding_cost=False) setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) - reshape_handler = ReshapeHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=reshape_strategies_vector) + reshape_handler = ReshapeHandler_V2(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=reshape_strategies_vector) reshape_handler.register_strategy(compute_resharding_cost=False) diff --git a/tests/test_auto_parallel/test_node_handler/test_unary_element_wise_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_unary_element_wise_handler_v2.py index 62265f6cb..7d8f6f10b 100644 --- a/tests/test_auto_parallel/test_node_handler/test_unary_element_wise_handler_v2.py +++ b/tests/test_auto_parallel/test_node_handler/test_unary_element_wise_handler_v2.py @@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear import torch import torch.nn as nn from colossalai.fx import ColoTracer, ColoGraphModule -from colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2 import UnaryElementwiseHandler +from colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2 from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh @@ -50,9 +50,9 @@ def test_elementwise_handler(): strategies_vector=conv_strategies_vector) conv_handler.register_strategy(compute_resharding_cost=False) setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) - relu_handler = UnaryElementwiseHandler(node=relu_mod_node, - device_mesh=device_mesh, - strategies_vector=relu_strategies_vector) + relu_handler = 
UnaryElementwiseHandler_V2(node=relu_mod_node, + device_mesh=device_mesh, + strategies_vector=relu_strategies_vector) relu_handler.register_strategy(compute_resharding_cost=False) diff --git a/tests/test_auto_parallel/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_solver_with_resnet_v2.py new file mode 100644 index 000000000..21dcbad56 --- /dev/null +++ b/tests/test_auto_parallel/test_solver_with_resnet_v2.py @@ -0,0 +1,99 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor_V2 +from colossalai.auto_parallel.solver.cost_graph import CostGraph_V2 +from copy import deepcopy +from colossalai.auto_parallel.solver.solver import Solver_V2 +from torchvision.models import resnet34, resnet50 +from colossalai.auto_parallel.solver.constants import * +from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.solver.options import SolverOptions + + +@pytest.mark.skip("for higher testing speed") +def test_cost_graph(): + physical_mesh_id = torch.arange(0, 8) + mesh_shape = (2, 4) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + model = resnet50(num_classes=100000) + input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')} + + graph = tracer.trace(root=model, meta_args=input_sample) + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {}) + # %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {}) + # %maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {}) + # %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {}) + # %layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {}) + # %layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {}) + # %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {}) + # %layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {}) + # %add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {}) + # %layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {}) + # %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {}) + # %layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {}) + # %layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {}) + # %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {}) + # %layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {}) + # %add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = 
{}) + # ... + # %avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_2_relu_1,), kwargs = {}) + # %flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {}) + # %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {}) + # return fc + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + graph_analyser = GraphAnalyser(gm) + liveness_list = graph_analyser.liveness_analysis() + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor_V2(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph_V2(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser) + + ret = solver.call_solver_serialized_args() + print(ret[0]) + print(solver.last_s_val) + strategies_list = solver.last_s_val + + computation_cost = 0 + communication_cost = 0 + communication_cost_bn = 0 + memory_cost = 0 + for index, node in enumerate(graph.nodes): + if node.op == 'call_module': + submod = node.graph.owning_module.get_submodule(node.target) + if type(submod) in BATCHNORM_MODULE_OP: + communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost.total + print(node.name, node.strategies_vector[strategies_list[index]].name) + computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total + communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total + node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total + if isinstance(node_memory_cost, tuple): + node_memory_cost = node_memory_cost[0] + memory_cost += node_memory_cost.activation + node_memory_cost.parameter + + print(f'computation cost is {computation_cost}') + print(f'communication cost is {communication_cost}') + print(f'memory cost is {memory_cost}') + print(f'bn communication cost is {communication_cost_bn}') + + +if __name__ == '__main__': + test_cost_graph()
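For reference, a minimal standalone sketch of the ILP that Solver_V2 builds, assuming a toy graph of two nodes with two strategies each and made-up costs (only pulp is required; the variable names c, d, r, s, e mirror the solver code above):

from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum, value

# c[i][k]: compute cost, d[i][k]: communication cost of strategy k of node i (toy numbers)
c = [[10, 6], [4, 9]]
d = [[0, 2], [1, 0]]
# r: resharding costs of the single edge, flattened row-major over (node-0 strategy, node-1 strategy)
r = [3, 0, 0, 5]

# one binary one-hot vector per node and one per edge, as in Solver_V2
s = [LpVariable.matrix(f"s[{i}]", (range(2),), cat="Binary") for i in range(2)]
e = LpVariable.matrix("e[0,1]", (range(4),), cat="Binary")

prob = LpProblem("toy_auto_sharding", LpMinimize)
# objective: node costs (compute + communication) plus edge (resharding) costs
prob += lpDot(s[0], c[0]) + lpDot(s[0], d[0]) + lpDot(s[1], c[1]) + lpDot(s[1], d[1]) + lpDot(e, r)
# each node picks exactly one strategy, and each edge picks exactly one strategy pair
prob += lpSum(s[0]) == 1
prob += lpSum(s[1]) == 1
prob += lpSum(e) == 1
# the edge choice must be consistent with both node choices
for row in range(2):
    prob += lpSum(e[row * 2 + col] for col in range(2)) <= s[0][row]
for col in range(2):
    prob += lpSum(e[row * 2 + col] for row in range(2)) <= s[1][col]

prob.solve()
print(LpStatus[prob.status], [[value(x) for x in si] for si in s], value(prob.objective))
# expected result: node 0 picks strategy 1, node 1 picks strategy 0, objective 13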